onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/whisper/benchmark_all.py
@@ -0,0 +1,528 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import argparse
+ import datetime
+ import json
+ import logging
+ import os
+ import subprocess
+
+ import librosa
+ import torch
+ from benchmark_helper import setup_logger
+ from metrics import BenchmarkRecord
+ from transformers import WhisperConfig, WhisperProcessor
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-a",
+         "--audio-path",
+         type=str,
+         required=True,
+         help="Path to folder of audio files for E2E evaluation",
+     )
+
+     parser.add_argument(
+         "-l",
+         "--language",
+         default=None,
+         help="Language of audio file",
+     )
+
+     parser.add_argument(
+         "-t",
+         "--task",
+         default=None,
+         choices=["transcribe", "translate"],
+         help="Task to complete",
+     )
+
+     parser.add_argument(
+         "-w",
+         "--warmup-runs",
+         type=int,
+         default=5,
+     )
+
+     parser.add_argument(
+         "-n",
+         "--num-runs",
+         type=int,
+         default=10,
+     )
+
+     parser.add_argument(
+         "--hf-pt-eager",
+         default=False,
+         action="store_true",
+         help="Benchmark in PyTorch without `torch.compile`",
+     )
+
+     parser.add_argument(
+         "--hf-pt-compile",
+         default=False,
+         action="store_true",
+         help="Benchmark in PyTorch with `torch.compile`",
+     )
+
+     parser.add_argument(
+         "--hf-ort-dir-path",
+         type=str,
+         help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
+     )
+
+     parser.add_argument(
+         "--ort-model-path",
+         type=str,
+         help="Path to ONNX model for ORT benchmarking",
+     )
+
+     parser.add_argument(
+         "--model-name",
+         type=str,
+         required=True,
+         help="Model name in Hugging Face (e.g. openai/whisper-large-v2)",
+     )
+
+     parser.add_argument(
+         "--precision",
+         type=str,
+         required=True,
+         choices=["int8", "fp16", "fp32"],
+         help="Precision to run model",
+     )
+
+     parser.add_argument(
+         "--device",
+         type=str,
+         required=True,
+         choices=["cpu", "cuda", "rocm"],
+         help="Device to benchmark models",
+     )
+
+     parser.add_argument(
+         "--device-id",
+         type=int,
+         default=0,
+         help="GPU device ID",
+     )
+
+     parser.add_argument(
+         "--verbose",
+         default=False,
+         action="store_true",
+         help="Print detailed logs",
+     )
+
+     parser.add_argument(
+         "--timeout",
+         type=int,
+         default=5,
+         help="Number of mins to attempt the benchmark before moving on",
+     )
+
+     parser.add_argument(
+         "--log-folder",
+         type=str,
+         default=None,
+         help="Path to folder to save logs and results",
+     )
+
+     parser.add_argument("--tune", default=False, action="store_true")
+
+     args = parser.parse_args()
+
+     setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-"))  # noqa: B010
+     log_folder_name = f"./{args.model_size}-{args.precision}"
+     if not args.log_folder:
+         args.log_folder = log_folder_name
+     os.makedirs(args.log_folder, exist_ok=True)
+
+     # Convert timeout value to secs
+     args.timeout *= 60
+
+     return args
+
+
+ def process_log_file(device_id, log_file, base_results):
+     entries = []
+
+     # Detect steps in speech pipeline
+     step = None
+     load_audio_pattern = "Load audio: "
+     feat_ext_pattern = "Feature extraction: "
+     pytorch_pattern = "Evaluating PyTorch..."
+     onnxruntime_pattern = "Evaluating ONNX Runtime..."
+
+     load_audio_latency_s, load_audio_throughput_s = None, None
+     feat_ext_latency_s, feat_ext_throughput_s = None, None
+     token_length, latency_s, per_token_latency_s, per_token_latency_ms = None, None, None, None
+     throughput, memory = None, None
+
+     # Detect metrics
+     latency_pattern = "Latency: "
+     throughput_pattern = "Throughput: "
+     token_length_pattern = "Generated token length: "
+     memory_pattern = "peak="
+
+     with open(log_file) as f:
+         for input_line in f:
+             line = input_line.replace("\n", "")
+
+             # Get step in speech recognition pipeline
+             if load_audio_pattern in line:
+                 step = "load-audio"
+             elif feat_ext_pattern in line:
+                 step = "feature-extraction"
+             elif pytorch_pattern in line or onnxruntime_pattern in line:
+                 step = "process"
+
+             # Check metrics
+             if latency_pattern in line:
+                 latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
+             elif throughput_pattern in line:
+                 throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
+                 if step == "load-audio":
+                     load_audio_latency_s, load_audio_throughput_s = latency_s, throughput
+                     step = None
+                 if step == "feature-extraction":
+                     feat_ext_latency_s, feat_ext_throughput_s = latency_s, throughput
+                     step = None
+             elif token_length_pattern in line:
+                 token_length = int(line[len(token_length_pattern) : line.rfind(" ")])
+                 per_token_latency_s = latency_s / token_length
+                 per_token_latency_ms = per_token_latency_s * 1000
+             elif memory_pattern in line:
+                 if "CPU" in line:
+                     # Example format for log entry:
+                     # CPU memory usage: before=1000.0 MB, peak=2000.0 MB
+                     memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
+                 else:
+                     # Example format for log entry:
+                     # GPU memory usage: before=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1638.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}, peak=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1780.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}]
+                     peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
+                     usage = json.loads(peak)[device_id]["max_used_MB"]
+                     memory = float(usage) / 1000
+
+                 # Calculate real-time factor (RTF):
+                 # RTF = total latency / audio duration
+                 total_latency = (
+                     (load_audio_latency_s if load_audio_latency_s else 0)
+                     + (feat_ext_latency_s if feat_ext_latency_s else 0)
+                     + (latency_s if latency_s else 0)
+                 )
+                 audio_duration = base_results[-1]
+                 rtf = (total_latency / audio_duration) if audio_duration else -1
+                 logger.info(f"Total latency: {total_latency} s")
+                 logger.info(f"Audio duration: {audio_duration} s")
+                 logger.info(f"Real-time factor: {rtf}")
+
+                 # Append log entry to list of entries
+                 entry = base_results + [  # noqa: RUF005
+                     token_length,
+                     load_audio_latency_s,
+                     load_audio_throughput_s,
+                     feat_ext_latency_s if feat_ext_latency_s else -1,
+                     feat_ext_throughput_s if feat_ext_throughput_s else -1,
+                     latency_s,
+                     per_token_latency_ms,
+                     throughput,
+                     memory,
+                     rtf,
+                 ]
+                 entries.append(entry)
+
+     return entries
+
+
+ def save_results(results, filename):
+     import pandas as pd
+
+     df = pd.DataFrame(
+         results,
+         columns=[
+             "Warmup Runs",
+             "Measured Runs",
+             "Model Name",
+             "Engine",
+             "Precision",
+             "Device",
+             "Audio File",
+             "Duration (s)",
+             "Token Length",
+             "Load Audio Latency (s)",
+             "Load Audio Throughput (qps)",
+             "Feature Extractor Latency (s)",
+             "Feature Extractor Throughput (qps)",
+             "Latency (s)",
+             "Per Token Latency (ms/token)",
+             "Throughput (qps)",
+             "Memory (GB)",
+             "Real Time Factor (RTF)",
+         ],
+     )
+
+     # Set column types
+     df["Warmup Runs"] = df["Warmup Runs"].astype("int")
+     df["Measured Runs"] = df["Measured Runs"].astype("int")
+     df["Duration (s)"] = df["Duration (s)"].astype("float")
+     df["Token Length"] = df["Token Length"].astype("int")
+     df["Load Audio Latency (s)"] = df["Load Audio Latency (s)"].astype("float")
+     df["Load Audio Throughput (qps)"] = df["Load Audio Throughput (qps)"].astype("float")
+     df["Feature Extractor Latency (s)"] = df["Feature Extractor Latency (s)"].astype("float")
+     df["Feature Extractor Throughput (qps)"] = df["Feature Extractor Throughput (qps)"].astype("float")
+     df["Latency (s)"] = df["Latency (s)"].astype("float")
+     df["Per Token Latency (ms/token)"] = df["Per Token Latency (ms/token)"].astype("float")
+     df["Throughput (qps)"] = df["Throughput (qps)"].astype("float")
+     df["Memory (GB)"] = df["Memory (GB)"].astype("float")
+     df["Real Time Factor (RTF)"] = df["Real Time Factor (RTF)"].astype("float")
+
+     # get package name and version
+     import pkg_resources
+
+     installed_packages = pkg_resources.working_set
+     installed_packages_list = sorted(
+         [f"{i.key}=={i.version}" for i in installed_packages if i.key in ["onnxruntime", "onnxruntime-gpu"]]
+     )
+     ort_pkg_name = ""
+     ort_pkg_version = ""
+     if installed_packages_list:
+         ort_pkg_name = installed_packages_list[0].split("==")[0]
+         ort_pkg_version = installed_packages_list[0].split("==")[1]
+
+     # Save results to csv with standard format
+     records = []
+     for _, row in df.iterrows():
+         if row["Engine"] == "onnxruntime":
+             record = BenchmarkRecord(
+                 row["Model Name"], row["Precision"], row["Engine"], row["Device"], ort_pkg_name, ort_pkg_version
+             )
+         else:
+             record = BenchmarkRecord(
+                 row["Model Name"], row["Precision"], row["Engine"], row["Device"], torch.__name__, torch.__version__
+             )
+         record.config.customized["audio_file"] = row["Audio File"]
+         record.config.warmup_runs = row["Warmup Runs"]
+         record.config.measured_runs = row["Measured Runs"]
+
+         record.metrics.customized["duration"] = row["Duration (s)"]
+         record.metrics.customized["token_length"] = row["Token Length"]
+         record.metrics.customized["load_audio_latency"] = row["Load Audio Latency (s)"]
+         record.metrics.customized["load_audio_throughput"] = row["Load Audio Throughput (qps)"]
+         record.metrics.customized["feature_extractor_latency_s"] = row["Feature Extractor Latency (s)"]
+         record.metrics.customized["feature_extractor_throughput_qps"] = row["Feature Extractor Throughput (qps)"]
+         record.metrics.customized["per_token_latency_ms"] = row["Per Token Latency (ms/token)"]
+         record.metrics.customized["rtf"] = row["Real Time Factor (RTF)"]
+
+         record.metrics.latency_ms_mean = row["Latency (s)"] * 1000
+         record.metrics.throughput_qps = row["Throughput (qps)"]
+         record.metrics.max_memory_usage_GB = row["Memory (GB)"]
+
+         records.append(record)
+
+     BenchmarkRecord.save_as_csv(filename, records)
+     BenchmarkRecord.save_as_json(filename.replace(".csv", ".json"), records)
+     logger.info(f"Results saved in {filename}!")
+
+
+ def benchmark(args, benchmark_cmd, engine, audio_file, duration):
+     log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
+     log_path = os.path.join(args.log_folder, log_filename)
+     with open(log_path, "w") as log_file:
+         process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
+         try:
+             process.wait(args.timeout)
+         except subprocess.TimeoutExpired:
+             process.kill()
+
+     # Create entries for csv
+     logger.info("Gathering data from log files...")
+     base_results = [
+         args.warmup_runs,
+         args.num_runs,
+         args.model_name,
+         engine,
+         args.precision,
+         args.device,
+         audio_file,
+         duration,
+     ]
+     results = process_log_file(args.device_id, log_path, base_results)
+
+     return results
+
+
+ def main():
+     args = get_args()
+     setup_logger(args.verbose)
+     logger.info(args.__dict__)
+     torch.backends.cudnn.benchmark = True
+
+     config = WhisperConfig.from_pretrained(args.model_name)
+     processor = WhisperProcessor.from_pretrained(args.model_name)
+
+     # Calculate forced decoder input ids
+     hf_forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task)
+     ort_forced_decoder_ids = [config.decoder_start_token_id] + list(  # noqa: RUF005
+         map(lambda token_id: token_id[1], hf_forced_decoder_ids)
+     )
+     hf_decoder_input_ids_cmd = (
+         ["--decoder-input-ids", str(hf_forced_decoder_ids)] if args.language and args.task else []
+     )
+     ort_decoder_input_ids_cmd = (
+         ["--decoder-input-ids", str(ort_forced_decoder_ids)] if args.language and args.task else []
+     )
+     ort_tune_cmd = ["--tune"] if args.tune else []
+
+     all_results = []
+     for audio_file in os.listdir(args.audio_path):
+         audio_path = os.path.join(args.audio_path, audio_file)
+         try:
+             duration = librosa.get_duration(path=audio_path)
+         except Exception as e:
+             duration = -1
+             logger.warning(f"An error occurred while trying to calculate the audio duration: {e}", exc_info=True)
+             logger.warning(
+                 f"If you get an error that says:\n\tsoundfile.LibsndfileError: Error opening '{audio_file}': File contains data in an unknown format.\nyou may not have installed `ffmpeg` in addition to installing `librosa`."
+             )
+         logger.info(f"Testing {audio_path}...")
+
+         # Benchmark PyTorch without torch.compile
+         if args.hf_pt_eager:
+             benchmark_cmd = [  # noqa: RUF005
+                 "python",
+                 "-m",
+                 "models.whisper.benchmark",
+                 "--audio-path",
+                 audio_path,
+                 "--benchmark-type",
+                 "hf-pt-eager",
+                 "--model-name",
+                 args.model_name,
+                 "--precision",
+                 args.precision,
+                 "--device",
+                 args.device,
+                 "--device-id",
+                 str(args.device_id),
+                 "--warmup-runs",
+                 str(args.warmup_runs),
+                 "--num-runs",
+                 str(args.num_runs),
+                 "--log-folder",
+                 args.log_folder,
+             ] + hf_decoder_input_ids_cmd
+             logger.info("Benchmark PyTorch without torch.compile")
+             results = benchmark(args, benchmark_cmd, "pytorch-eager", audio_file, duration)
+             all_results.extend(results)
+
+         # Benchmark PyTorch with torch.compile
+         if args.hf_pt_compile:
+             benchmark_cmd = [  # noqa: RUF005
+                 "python",
+                 "-m",
+                 "models.whisper.benchmark",
+                 "--audio-path",
+                 audio_path,
+                 "--benchmark-type",
+                 "hf-pt-compile",
+                 "--model-name",
+                 args.model_name,
+                 "--precision",
+                 args.precision,
+                 "--device",
+                 args.device,
+                 "--device-id",
+                 str(args.device_id),
+                 "--warmup-runs",
+                 str(args.warmup_runs),
+                 "--num-runs",
+                 str(args.num_runs),
+                 "--log-folder",
+                 args.log_folder,
+             ] + hf_decoder_input_ids_cmd
+             logger.info("Benchmark PyTorch with torch.compile")
+             results = benchmark(args, benchmark_cmd, "pytorch-compile", audio_file, duration)
+             all_results.extend(results)
+
+         # Benchmark Optimum + ONNX Runtime
+         if args.hf_ort_dir_path:
+             benchmark_cmd = [  # noqa: RUF005
+                 "python",
+                 "-m",
+                 "models.whisper.benchmark",
+                 "--audio-path",
+                 audio_path,
+                 "--benchmark-type",
+                 "hf-ort",
+                 "--hf-ort-dir-path",
+                 args.hf_ort_dir_path,
+                 "--model-name",
+                 args.model_name,
+                 "--precision",
+                 args.precision,
+                 "--device",
+                 args.device,
+                 "--device-id",
+                 str(args.device_id),
+                 "--warmup-runs",
+                 str(args.warmup_runs),
+                 "--num-runs",
+                 str(args.num_runs),
+                 "--log-folder",
+                 args.log_folder,
+             ] + hf_decoder_input_ids_cmd
+             logger.info("Benchmark Optimum + ONNX Runtime")
+             results = benchmark(args, benchmark_cmd, "optimum-ort", audio_file, duration)
+             all_results.extend(results)
+
+         # Benchmark ONNX Runtime
+         if args.ort_model_path:
+             benchmark_cmd = (
+                 [  # noqa: RUF005
+                     "python",
+                     "-m",
+                     "models.whisper.benchmark",
+                     "--audio-path",
+                     audio_path,
+                     "--benchmark-type",
+                     "ort",
+                     "--ort-model-path",
+                     args.ort_model_path,
+                     "--model-name",
+                     args.model_name,
+                     "--precision",
+                     args.precision,
+                     "--device",
+                     args.device,
+                     "--device-id",
+                     str(args.device_id),
+                     "--warmup-runs",
+                     str(args.warmup_runs),
+                     "--num-runs",
+                     str(args.num_runs),
+                     "--log-folder",
+                     args.log_folder,
+                 ]
+                 + ort_decoder_input_ids_cmd
+                 + ort_tune_cmd
+             )
+             logger.info("Benchmark ONNX Runtime")
+             results = benchmark(args, benchmark_cmd, "onnxruntime", audio_file, duration)
+             all_results.extend(results)
+
+     csv_file = f"{args.model_size}-{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
+     save_results(all_results, os.path.join(args.log_folder, csv_file))
+
+
+ if __name__ == "__main__":
+     main()
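
The file above is a driver script: for every audio file in --audio-path it launches models.whisper.benchmark as a subprocess once per selected engine (PyTorch eager, torch.compile, Optimum + ORT, plain ORT), captures each run's output in a log file, parses the metrics back out of the logs, and aggregates everything into a CSV/JSON report. A minimal sketch of how it might be invoked is shown below, assuming it is run from the package's transformers directory; the audio folder, ONNX model path, and run counts are illustrative values, not anything shipped with the package.

import subprocess

# Hypothetical invocation of the driver script; every path and value here is an example.
cmd = [
    "python",
    "-m",
    "models.whisper.benchmark_all",
    "--audio-path", "test_audio/",                        # example folder of audio clips
    "--model-name", "openai/whisper-tiny",                # Hugging Face model id
    "--precision", "fp32",
    "--device", "cpu",
    "--ort-model-path", "whisper_tiny_beamsearch.onnx",   # example exported ONNX model
    "--warmup-runs", "2",
    "--num-runs", "5",
    "--log-folder", "logs",
]
subprocess.run(cmd, check=True)
# Per the script above, results are written to logs/<model_size>-<precision>_<timestamp>.csv
# plus a matching .json file, with one row per audio file and engine.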