onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,414 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import time
7
+
8
+ import numpy as np
9
+ import torch
10
+ from transformers import AutoTokenizer
11
+
12
+ import onnxruntime as ort
13
+
14
# Map repr(torch.dtype) -> NumPy dtype. Keys match repr() of the torch dtype
# objects; used by ORTGenerator.apply_io_binding to pick the element_type
# when binding torch tensors to ONNX Runtime inputs.
pt_to_np = {
    "torch.int32": np.int32,
    "torch.int64": np.int64,
    "torch.float32": np.float32,
    "torch.float16": np.float16,
}
20
+
21
+
22
def cuda_memcpy(dst, src):
    """Copy the bytes of tensor ``src`` into tensor ``dst`` with the CUDA runtime.

    The copy size is taken from ``src`` (element size times element count) and
    the transfer kind is device-to-device, so both tensors are expected to
    live in CUDA memory.
    """
    from cuda import cudart  # noqa: PLC0415

    num_bytes = src.element_size() * src.nelement()
    cudart.cudaMemcpy(
        dst.data_ptr(),
        src.data_ptr(),
        num_bytes,
        cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
    )
31
+
32
+
33
+ class ORTGenerator:
34
+ def __init__(self, decoder_path):
35
+ self.onnx_decoder_path = decoder_path
36
+ self.num_heads = 32
37
+ self.head_size = 80
38
+ self.num_layers = 32
39
+ self.max_sequence_length = 2048
40
+ self.device_id = 0
41
+ self.use_cuda_graph = False
42
+ self.use_traced_inputs = False
43
+ self.static_inputs_map = {}
44
+
45
+ def append_static_inputs(self, batch_size):
46
+ # Only use this function with GQA and with use_cuda_graph=True
47
+ if batch_size in self.static_inputs_map:
48
+ return
49
+
50
+ cpu_device = torch.device("cpu")
51
+ cuda_device = torch.device("cuda", self.device_id)
52
+
53
+ static_io = {}
54
+ static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device)
55
+ static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device)
56
+ static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device)
57
+ static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device)
58
+
59
+ cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size)
60
+ for i in range(self.num_layers):
61
+ cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16)
62
+ static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()})
63
+
64
+ static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device)
65
+
66
+ self.static_inputs_map[batch_size] = static_io
67
+
68
    def get_initial_inputs_and_outputs(self, encodings_dict):
        """Build the input/output tensor dicts for the prompt (first) run.

        ``encodings_dict`` is a tokenizer output providing "input_ids" and
        "attention_mask". Returns ``(inputs, outputs)`` mapping tensor names
        to torch tensors — either freshly allocated, or (when CUDA-graph
        tracing applies) the pre-allocated buffers from
        ``self.static_inputs_map``.
        """
        self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32

        input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
        attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)

        batch_size, sequence_length = input_ids.shape

        # Traced (static) inputs require CUDA graphs, buffer sharing,
        # unpacked KV, and a batch size prepared by append_static_inputs().
        self.use_traced_inputs = (
            self.use_cuda_graph
            and (batch_size in self.static_inputs_map)
            and self.use_buffer_share
            and not self.packed_kv
        )

        step = (
            torch.tensor([0], device=self.device, dtype=torch.int64)
            if not self.use_traced_inputs
            else self.static_inputs_map[batch_size]["step"]
        )

        seqlens_k = (
            torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32)
            if not self.use_traced_inputs
            else self.static_inputs_map[batch_size]["seqlens_k"]
        )
        # Per-row index of the last valid token: mask.sum(1) - 1.
        # NOTE(review): cuda_memcpy is issued even on the non-traced path —
        # presumably seqlens_k is only consumed on CUDA runs; confirm CPU path.
        cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32))

        total_seq_length = (
            torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32)
            if not self.use_traced_inputs
            else self.static_inputs_map[batch_size]["total_sequence_length"]
        )
        total_seq_length[0] = sequence_length

        inputs = {
            "input_ids": input_ids.contiguous(),
            "attention_mask": attention_mask.contiguous(),
        }

        if self.use_step:
            inputs["step"] = step.contiguous()

        if self.use_cuda_graph:
            # CUDA-graph path feeds sequence lengths instead of a dense mask.
            inputs["seqlens_k"] = seqlens_k.contiguous()
            inputs["total_sequence_length"] = total_seq_length.contiguous()
            del inputs["attention_mask"]

        # With buffer sharing the KV cache is allocated once at max length;
        # otherwise the prompt run starts from an empty (length-0) cache.
        past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
        past_shape = (
            (2, batch_size, self.num_heads, past_seq_length, self.head_size)
            if self.packed_kv
            else (batch_size, self.num_heads, past_seq_length, self.head_size)
        )

        if not self.use_traced_inputs:
            for i in range(self.num_layers):
                past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
                # Unpacked KV gets distinct key/value buffers (value is cloned
                # so they do not alias); packed KV gets one combined tensor.
                (
                    inputs.update({f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()})
                    if not self.packed_kv
                    else inputs.update({f"past_{i}": past.contiguous()})
                )
        else:
            # Reuse the persistent static buffers allocated per batch size.
            for i in range(self.num_layers):
                inputs.update(
                    {
                        f"past_key_{i}": self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(),
                        f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(),
                    }
                )

        # 51200 matches the logits width used throughout this file.
        logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
        outputs = {"logits": logits.contiguous()}

        if not self.use_buffer_share:
            # Without buffer sharing, present KV tensors are separate outputs.
            present_shape = (
                (2, batch_size, self.num_heads, sequence_length, self.head_size)
                if self.packed_kv
                else (batch_size, self.num_heads, sequence_length, self.head_size)
            )
            for i in range(self.num_layers):
                present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
                (
                    outputs.update(
                        {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
                    )
                    if not self.packed_kv
                    else outputs.update({f"present_{i}": present.contiguous()})
                )

        return inputs, outputs
160
+
161
+ def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict):
162
+ io_binding = model.io_binding()
163
+ device = None
164
+
165
+ for k, v in inputs.items():
166
+ io_binding.bind_input(
167
+ name=k,
168
+ device_type=v.device.type,
169
+ device_id=0 if v.device.type == "cpu" else v.device.index,
170
+ element_type=pt_to_np[repr(v.dtype)],
171
+ shape=tuple(v.shape),
172
+ buffer_ptr=v.data_ptr(),
173
+ )
174
+ device = v.device
175
+
176
+ for output in model.get_outputs():
177
+ name = output.name
178
+ if self.use_buffer_share and "present" in name:
179
+ v = inputs[name.replace("present", "past")]
180
+ io_binding.bind_output(
181
+ name=name,
182
+ device_type=v.device.type,
183
+ device_id=v.device.index,
184
+ element_type=(np.float16 if self.use_fp16 else np.float32),
185
+ shape=tuple(v.shape),
186
+ buffer_ptr=v.data_ptr(),
187
+ )
188
+ else:
189
+ v = outputs[name]
190
+ io_binding.bind_output(
191
+ name=name,
192
+ device_type=device.type,
193
+ device_id=0 if device.type == "cpu" else device.index,
194
+ element_type=(np.float16 if self.use_fp16 else np.float32),
195
+ shape=tuple(v.shape),
196
+ buffer_ptr=v.data_ptr(),
197
+ )
198
+
199
+ return io_binding
200
+
201
    def create_session(
        self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False
    ):
        """Create the ORT inference session and record generation settings.

        Args:
            device_id: CUDA device ordinal; a negative value selects the CPU EP.
            use_fp16: allocate/run tensors in float16 instead of float32.
            use_buffer_share: bind present KV outputs onto the past KV buffers.
            packed_kv: use a single packed ``past_{i}`` tensor per layer.
            use_step: feed a "step" input to the model.
            use_cuda_graph: enable CUDA graph capture in the CUDA EP.
        """
        self.device_id = device_id
        sess_options = ort.SessionOptions()
        # High verbosity/severity thresholds keep ORT logging quiet.
        sess_options.log_verbosity_level = 4
        sess_options.log_severity_level = 4
        self.use_cuda_graph = use_cuda_graph
        ep = (
            ("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph})
            if self.device_id >= 0
            else "CPUExecutionProvider"
        )
        self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
        self.ro = ort.RunOptions()

        # NOTE(review): device selection ignores device_id < 0 — with CUDA
        # available this picks a cuda device even when the CPU EP was chosen
        # above; confirm that is intended.
        self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu")
        self.use_fp16 = use_fp16
        self.use_buffer_share = use_buffer_share
        self.packed_kv = packed_kv
        self.use_step = use_step

        # Tokenizer for the phi-2 model; a pad token is assigned explicitly.
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
        self.tokenizer.pad_token = "[PAD]"
225
+
226
+ def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False):
227
+ inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
228
+
229
+ all_token_ids = inputs["input_ids"].clone()
230
+ batch_size, sequence_length = all_token_ids.shape
231
+
232
+ current_length = sequence_length
233
+ has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
234
+
235
+ if benchmark:
236
+ latency = []
237
+
238
+ prompt_run = True
239
+ while current_length < max_length:
240
+ io_binding = self.apply_io_binding(self.sess, inputs, outputs)
241
+
242
+ if benchmark:
243
+ start = time.time()
244
+
245
+ io_binding.synchronize_inputs()
246
+ if prompt_run:
247
+ if self.use_cuda_graph:
248
+ # Disable CUDA graph for the prompt run
249
+ self.ro.add_run_config_entry("gpu_graph_id", "-1")
250
+ self.sess.run_with_iobinding(io_binding, self.ro)
251
+ if self.use_cuda_graph:
252
+ # Enable CUDA graph for the decoding run
253
+ self.ro.add_run_config_entry(
254
+ "gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1"
255
+ )
256
+ prompt_run = False
257
+ else:
258
+ self.sess.run_with_iobinding(io_binding, self.ro)
259
+ io_binding.synchronize_outputs()
260
+
261
+ if benchmark:
262
+ end = time.time()
263
+ latency.append(end - start)
264
+
265
+ # Sample with argmax (greedy search)
266
+ next_token_logits = outputs["logits"][:, -1, :]
267
+ next_tokens = torch.argmax(next_token_logits, dim=-1)
268
+
269
+ # Check if we previously reached EOS token id or if generated token id is EOS token id
270
+ has_eos = has_eos | next_tokens == self.tokenizer.eos_token_id
271
+
272
+ # Determine which new tokens to add to list of all token ids
273
+ # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
274
+ tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1])
275
+ all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
276
+
277
+ # Return early if all batch entries have reached EOS token id
278
+ if torch.all(has_eos):
279
+ break
280
+
281
+ # Update inputs for next inference run
282
+ current_length += 1
283
+
284
+ inputs["input_ids"] = tokens_to_add.to(torch.int32)
285
+ if self.use_traced_inputs:
286
+ cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"])
287
+ inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"]
288
+
289
+ if self.use_step:
290
+ inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
291
+ if self.use_traced_inputs:
292
+ cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"])
293
+ inputs["step"] = self.static_inputs_map[batch_size]["step"]
294
+
295
+ if self.use_cuda_graph:
296
+ previous_seqlens_k = inputs["seqlens_k"]
297
+ inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32)
298
+ inputs["total_sequence_length"][0] = current_length
299
+ if self.use_traced_inputs:
300
+ cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"])
301
+ inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"]
302
+ self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0]
303
+ inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"]
304
+ else:
305
+ inputs["attention_mask"] = torch.cat(
306
+ [inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1
307
+ ).to(torch.int32)
308
+
309
+ # Set logits to zeros for next inference run and re-use memory buffer
310
+ if outputs["logits"].shape[1] != 1:
311
+ outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
312
+ if self.use_traced_inputs:
313
+ outputs["logits"] = self.static_inputs_map[batch_size]["logits"]
314
+ outputs["logits"].zero_()
315
+
316
+ if not self.use_buffer_share:
317
+ for i in range(self.num_layers):
318
+ if not self.packed_kv:
319
+ inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"]
320
+ inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"]
321
+ else:
322
+ inputs[f"past_{i}"] = outputs[f"present_{i}"]
323
+
324
+ new_sequence_length = inputs["attention_mask"].shape[1]
325
+ present_shape = (
326
+ (2, batch_size, self.num_heads, new_sequence_length, self.head_size)
327
+ if self.packed_kv
328
+ else (batch_size, self.num_heads, new_sequence_length, self.head_size)
329
+ )
330
+ for i in range(self.num_layers):
331
+ present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
332
+ (
333
+ outputs.update(
334
+ {
335
+ f"present_key_{i}": present.contiguous(),
336
+ f"present_value_{i}": present.clone().contiguous(),
337
+ }
338
+ )
339
+ if not self.packed_kv
340
+ else outputs.update({f"present_{i}": present.contiguous()})
341
+ )
342
+
343
+ if benchmark:
344
+ print(
345
+ f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}"
346
+ )
347
+ print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms")
348
+ return
349
+
350
+ texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
351
+ return texts
352
+
353
+ def generate(self, prompt, max_length, cuda_graph_annotation):
354
+ encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
355
+
356
+ return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation)
357
+
358
+ def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation):
359
+ batch_size, sequence_length = prompt_shape
360
+ max_length = sequence_length + token_num
361
+
362
+ encodings_dict = {}
363
+ encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist()
364
+ encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist()
365
+
366
+ # Warm up run
367
+ self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False)
368
+
369
+ # Benchmark run
370
+ self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True)
371
+
372
+
373
def run_phi2(
    onnx_model_path,
    use_buffer_share,
    device_id,
    packed_kv=False,
    use_fp16=True,
    use_step=False,
    use_cuda_graph=False,
    run_benchmark=False,
):
    """Create an ORTGenerator session for a phi-2 ONNX model and either run a
    sample prompt or a small decoding benchmark.

    Args:
        onnx_model_path: Path to the exported phi-2 ONNX model.
        use_buffer_share: Whether KV-cache buffers are shared between past/present.
        device_id: GPU device id to run on.
        packed_kv: Whether the model uses packed (combined) KV-cache inputs.
        use_fp16: Run the session in float16.
        use_step: Whether the model takes an explicit "step" input.
        use_cuda_graph: Enable CUDA-graph capture/replay for decoding runs.
        run_benchmark: If True, run the timing benchmark instead of the sample prompt.
    """
    generator = ORTGenerator(onnx_model_path)
    generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph)

    def simple_run(prompt):
        # Generate up to 210 total tokens for the given prompt batch; the batch
        # size doubles as the CUDA-graph annotation id.
        example_batch_size = len(prompt)
        if use_cuda_graph:
            generator.append_static_inputs(batch_size=example_batch_size)
        texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size)

        for i in range(len(texts)):
            print("Prompt: ", prompt[i])
            print("Texts: ", texts[i])

    # Sample code-completion prompt (from the phi-2 model card).
    # NOTE(review): the exact interior whitespace of this literal affects model
    # output only, not control flow — confirm against the upstream example.
    prompt = [
        '''```python
def print_prime(n):
   """
   Print all primes between 1 and n
   """'''
    ]

    if not run_benchmark:
        simple_run(prompt)

    # Run simple benchmark. Time the decoder only.
    if run_benchmark:
        token_num = 32
        for batch_size in [1, 2, 4, 8]:
            generator.append_static_inputs(batch_size)
            for sequence_length in [16, 512]:
                prompt_shape = (batch_size, sequence_length)
                generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)
@@ -0,0 +1,12 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
import os.path
import sys

# Make this script's directory importable (sibling modules in this package).
# Guarded like the transformers_dir append below so repeated imports do not
# accumulate duplicate sys.path entries.
_script_dir = os.path.dirname(__file__)
if _script_dir not in sys.path:
    sys.path.append(_script_dir)

# Expose the transformers tools root (two levels up) for shared helper modules.
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)