onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/whisper/benchmark.py
@@ -0,0 +1,610 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import argparse
+ import ast
+ import datetime
+ import gc
+ import logging
+ import os
+ import sys
+ import time
+
+ import numpy as np
+ import psutil
+ import torch
+ import whisper
+ from benchmark_helper import measure_memory, setup_logger
+ from onnxruntime_extensions import get_library_path
+ from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
+ from torch.profiler import ProfilerActivity, profile, record_function
+ from tqdm import trange
+ from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor
+
+ import onnxruntime as ort
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_inputs(args: argparse.Namespace):
+     if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}:
+         raise Exception("Unable to auto-detect inputs for provided model")
+
+     def load_via_ffmpeg():
+         audio = whisper.load_audio(args.audio_path)
+         audio = whisper.pad_or_trim(audio)
+         return audio
+
+     def load_via_numpy():
+         with open(args.audio_path, "rb") as f:
+             audio = np.asarray(list(f.read()), dtype=np.uint8)
+         audio = np.array([audio])
+         return audio
+
+     inputs = {
+         "max_length": args.max_length,
+         "min_length": args.min_length,
+         "num_beams": args.num_beams,
+         "num_return_sequences": args.num_return_sequences,
+         "length_penalty": args.length_penalty,
+         "repetition_penalty": args.repetition_penalty,
+     }
+     if args.benchmark_type == "ort":
+         # convert_to_onnx export or ONNX E2E solution created by Olive
+         for k, v in inputs.items():
+             inputs[k] = np.array([v], dtype=np.float32 if "penalty" in k else np.int32)
+         if args.has_decoder_input_ids:
+             inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
+         if args.has_logits_processor:
+             inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)
+         if args.has_temperature:
+             inputs["temperature"] = np.array([args.temperature], dtype=np.float32)
+
+     # Measure time taken to load audio file
+     logger.info(f"Load audio: {args.audio_path}")
+     load_audio_fn = lambda onnx_e2e: load_via_numpy() if onnx_e2e else load_via_ffmpeg()  # noqa: E731
+     time_fn(args, load_audio_fn, args.has_audio_stream)
+     audio_data = load_audio_fn(args.has_audio_stream)
+
+     if args.has_audio_stream:
+         # ONNX E2E solution created by Olive
+         inputs["audio_stream"] = audio_data
+         return inputs
+
+     # Measure time taken to get input features
+     logger.info("Feature extraction: ")
+     return_type = "np" if args.benchmark_type == "ort" else "pt"
+     processor_fn = lambda audio: args.processor.feature_extractor(  # noqa: E731
+         [audio], return_tensors=return_type, sampling_rate=args.sampling_rate
+     ).input_features
+     time_fn(args, processor_fn, audio_data)
+     input_features = processor_fn(audio_data)
+
+     if args.benchmark_type == "ort":
+         # convert_to_onnx export
+         inputs["input_features"] = input_features
+         return inputs
+
+     inputs["inputs"] = input_features.to(
+         dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device
+     )
+     inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size
+     inputs["early_stopping"] = True
+     inputs["use_cache"] = True
+
+     if args.decoder_input_ids:
+         inputs["forced_decoder_ids"] = args.decoder_input_ids
+
+     return inputs
+
+
+ def get_model(args: argparse.Namespace):
+     model, sess_options = None, None
+     start_time, end_time = None, None
+
+     # There are multiple sources that the model could come from:
+     # 1) Benchmark Whisper from Hugging Face
+     # 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing)
+     # 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing)
+
+     if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+         source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
+         start_time = time.time()
+         model = AutoModelForSpeechSeq2Seq.from_pretrained(
+             source,
+             torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
+             use_cache=True,
+         ).to(args.target_device)
+         end_time = time.time()
+
+         if args.benchmark_type == "hf-pt-compile":
+             model = torch.compile(model)
+
+     elif args.benchmark_type in {"hf-ort", "ort"}:
+         sess_options = ort.SessionOptions()
+         sess_options.enable_profiling = args.profile
+         sess_options.register_custom_ops_library(get_library_path())
+         if args.verbose:
+             sess_options.log_verbosity_level = 1
+             sess_options.log_severity_level = 1
+         if args.tune:
+             ort.set_default_logger_severity(0)
+             ort.set_default_logger_verbosity(0)
+
+     else:
+         raise Exception(f"Cannot recognize {args.benchmark_type}")
+
+     if args.benchmark_type == "hf-ort":
+         # Optimum export
+         provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
+         provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
+
+         start_time = time.time()
+         model = ORTModelForSpeechSeq2Seq.from_pretrained(
+             args.hf_ort_dir_path,
+             provider=provider,
+             provider_options=provider_options,
+             session_options=sess_options,
+             use_io_binding=True,  # Avoid memory copy overhead
+         )
+         end_time = time.time()
+
+     if args.benchmark_type == "ort":
+         # convert_to_onnx.py export
+         logger.info(f"Loading model from {args.ort_model_path}")
+         start_time = time.time()
+         model = ort.InferenceSession(
+             args.ort_model_path,
+             sess_options,
+             providers=[args.execution_provider],
+         )
+         end_time = time.time()
+
+     logger.info(f"Loaded model in {end_time - start_time} s")
+
+     return model
+
+
+ def time_fn(args, fn, inputs):
+     warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
+     benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
+     torch_device = torch.device(args.target_device)
+
+     # Warm up
+     warmup_range = (
+         range(args.warmup_runs)
+         if args.benchmark_type == "ort"
+         else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
+     )
+
+     if args.verbose:
+         outputs = fn(warmup_inputs)
+         logger.info(outputs)
+
+     for _ in warmup_range:
+         fn(warmup_inputs)
+
+     # Benchmark
+     if args.device != "cpu":
+         torch.cuda.synchronize(torch_device)
+     start_time = time.time()
+
+     bench_range = (
+         range(args.num_runs)
+         if args.benchmark_type == "ort"
+         else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
+     )
+     for _ in bench_range:
+         fn(benchmark_inputs)
+
+     if args.device != "cpu":
+         torch.cuda.synchronize(torch_device)
+     end_time = time.time()
+
+     # Newline print after trange in order to print metrics on new lines without progress bar on same line
+     if args.benchmark_type != "ort":
+         logger.info("")
+
+     batch_size = 1
+     latency = (end_time - start_time) / args.num_runs
+     throughput = batch_size / latency
+
+     logger.info(f"Latency: {latency} s")
+     logger.info(f"Throughput: {throughput} qps")
+     return
+
+
+ def profile_fn(args, fn, inputs, inputs_type):
+     # Filename prefix format:
+     # "<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
+     prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
+     filename = None
+
+     if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+         # Profile PyTorch kernels
+         with profile(  # noqa: SIM117
+             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
+         ) as prof:
+             with record_function("model_inference"):
+                 fn(inputs)
+         prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
+
+         filename = os.path.join(args.log_folder, f"{prefix}.log")
+         with open(filename, "w") as f:
+             f.write(prof_data)
+
+     else:
+         # Profile ORT kernels
+         fn(inputs)
+
+         # Set new log name for ORT profile log generated
+         filename = f"{prefix}.json"
+
+     return filename
+
+
+ def measure_fn(args, fn, inputs):
+     # Measure CPU usage
+     pid = os.getpid()
+     process = psutil.Process(pid)
+     process.cpu_percent(interval=0.1)
+
+     fn(inputs)
+     logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")
+
+     # Measure memory usage
+     gc.collect()
+     torch.cuda.empty_cache()
+     measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type)
+
+     # Flush output so memory usage is printed
+     sys.stdout.flush()
+
+
+ def run_hf_inference(args, inputs, model):
+     # Inference steps to measure
+     def get_pred_ids(inputs):
+         # Inference pass with predicted token ids generation
+         predicted_ids = model.generate(**inputs)
+         return predicted_ids
+
+     def gen_and_dec(inputs):
+         # Inference pass with generation and decoding
+         predicted_ids = get_pred_ids(inputs)
+         transcription = []
+         for _ in range(args.num_return_sequences):
+             transcription.append(args.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
+         return predicted_ids, transcription
+
+     # Examples of other inference steps that can be measured:
+     # To use, uncomment the function and assign it to `generate_fn`
+
+     # def get_logits(inputs):
+     #     # Inference pass without decoding
+     #     outputs = model(**inputs)
+     #     return outputs
+
+     generate_fn = gen_and_dec
+
+     if args.benchmark_type == "hf-pt-compile":
+         # Run forward pass once with each set of inputs to process through Dynamo
+         generate_fn(inputs)
+
+     if args.profile:
+         new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec")
+         if args.benchmark_type == "hf-ort":
+             # Rename log files per model component and turn profiling off to stop appending to log
+             new_prefix = new_logname[: -len(".json")]
+
+             old_logname = model.encoder.session.end_profiling()
+             new_logname = new_prefix + "-encoder.json"
+             if os.path.isfile(old_logname):
+                 logger.warning(f"Renaming {old_logname} to {new_logname}")
+                 os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+             old_logname = model.decoder.session.end_profiling()
+             new_logname = new_prefix + "-decoder.json"
+             if os.path.isfile(old_logname):
+                 logger.warning(f"Renaming {old_logname} to {new_logname}")
+                 os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+             old_logname = model.decoder_with_past.session.end_profiling()
+             new_logname = new_prefix + "-decoder-with-past.json"
+             if os.path.isfile(old_logname):
+                 logger.warning(f"Renaming {old_logname} to {new_logname}")
+                 os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+         return
+
+     # PyTorch evaluations
+     logger.info("\nEvaluating PyTorch...")
+     time_fn(args, generate_fn, inputs)
+     predicted_ids, transcription = generate_fn(inputs)
+     logger.info(f"Generated token length: {len(predicted_ids[0])} tokens")
+     logger.info(f"Transcription: {transcription[0]}")
+     measure_fn(args, generate_fn, inputs)
+
+
+ def run_ort_inference(args, inputs, model):
+     def prepare_ort_inputs(inputs, warmup=False):
+         # Check that all model inputs will be provided
+         model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
+         user_inputs = set(inputs.keys())
+         missing_inputs = model_inputs - user_inputs
+         if len(missing_inputs):
+             logger.error(f"The following model inputs are missing: {missing_inputs}")
+             raise Exception("There are missing inputs to the model. Please add them and try again.")
+
+         if warmup and args.tune:
+             inputs["min_length"] = inputs["max_length"]
+
+         # Remove unnecessary inputs from model inputs
+         unnecessary_inputs = user_inputs - model_inputs
+         if len(unnecessary_inputs):
+             for unnecessary_input in unnecessary_inputs:
+                 logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
+                 del inputs[unnecessary_input]
+
+         # Add IO bindings for non-CPU execution providers
+         if args.device != "cpu":
+             io_binding = model.io_binding()
+             for k, v in inputs.items():
+                 io_binding.bind_cpu_input(k, v)
+             for output in model.get_outputs():
+                 io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id)
+             return io_binding
+
+         return inputs
+
+     def with_io_binding(io_binding):
+         # Inference pass with IO binding
+         model.run_with_iobinding(io_binding)
+         return io_binding
+
+     def without_io_binding(inputs):
+         # Inference pass without IO binding
+         outputs = model.run(None, inputs)
+         return outputs
+
+     def handle_output(output):
+         if args.eos_token_id in output:
+             first_end = np.where(output == args.eos_token_id)[0][0]
+             return output[: first_end + 1]
+
+         return output
+
+     generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
+     ort_inputs = prepare_ort_inputs(inputs)
+
+     if args.profile:
+         new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e")
+
+         # Turn profiling off to stop appending to log file
+         old_logname = model.end_profiling()
+         logger.warning(f"Renaming {old_logname} to {new_logname}")
+         os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+         return
+
+     # ORT evaluation
+     logger.info("\nEvaluating ONNX Runtime...")
+     ort_evaluate_inputs = ort_inputs
+     if args.tune:
+         ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True)
+         ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs)
+
+     time_fn(args, generate_fn, ort_evaluate_inputs)
+     ort_outputs = generate_fn(ort_inputs)
+     if args.device != "cpu":
+         ort_outputs = ort_outputs.copy_outputs_to_cpu()
+     ort_outputs = ort_outputs[0]
+
+     if args.has_audio_stream:
+         # ONNX E2E model from Olive produces transcribed output
+         logger.info(f"Transcription: {ort_outputs[0][0]}")
+     else:
+         # convert_to_onnx model produces generated ids
+         actual_output = handle_output(ort_outputs[0][0])
+         logger.info(f"Generated token length: {len(actual_output)} tokens")
+         transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
+         # print to stdout as the output for comparison
+         print(f"{transcription}")
+
+     measure_fn(args, generate_fn, ort_inputs)
+
+
+ def run_inference(args, inputs, model):
+     if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
+         run_hf_inference(args, inputs, model)
+     elif args.benchmark_type == "ort":
+         run_ort_inference(args, inputs, model)
+     else:
+         raise Exception(f"Cannot recognize {args.benchmark_type}")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-bt",
+         "--benchmark-type",
+         type=str,
+         required=True,
+         choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"],
+     )
+
+     parser.add_argument(
+         "-m",
+         "--model-name",
+         type=str,
+         required=True,
+         help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')",
+     )
+     parser.add_argument(
+         "-p",
+         "--precision",
+         type=str,
+         required=True,
+         default="fp32",
+         choices=["int8", "fp16", "fp32"],
+         help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
+     )
+
+     parser.add_argument(
+         "--hf-pt-model-path",
+         type=str,
+         default="",
+         help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
+     )
+     parser.add_argument(
+         "--hf-ort-dir-path",
+         type=str,
+         default="",
+         help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
+     )
+     parser.add_argument(
+         "--ort-model-path",
+         type=str,
+         default="",
+         help="Path to ONNX model",
+     )
+
+     # Args for running and evaluating the model
+     parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation")
+     parser.add_argument(
+         "-d",
+         "--device",
+         type=str,
+         default="cuda" if torch.cuda.is_available() else "cpu",
+         choices=["cpu", "cuda", "rocm"],
+     )
+     parser.add_argument("-id", "--device-id", type=int, default=0)
+     parser.add_argument("-w", "--warmup-runs", type=int, default=5)
+     parser.add_argument("-n", "--num-runs", type=int, default=10)
+     parser.add_argument("--seed", type=int, default=2)
+
+     # Optional args:
+     parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)")
+
+     # Args for decoding logic
+     # Required args:
+     parser.add_argument("--max-length", type=int, default=448)
+     parser.add_argument("--min-length", type=int, default=0)
+     parser.add_argument("--num-beams", type=int, default=1)
+     parser.add_argument("--num-return-sequences", type=int, default=1)
+     parser.add_argument("--length-penalty", type=float, default=1.0)
+     parser.add_argument("--repetition-penalty", type=float, default=1.0)
+     parser.add_argument("--no-repeat-ngram-size", type=int, default=3)
+
+     # Optional args for E2E solution:
+     parser.add_argument(
+         "--decoder-input-ids",
+         type=str,
+         default="[]",
+         help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.",
+     )
+     parser.add_argument(
+         "--logits-processor",
+         type=int,
+         default=1,
+         help="Whether to use timestamps logits processor or not (0 for false, 1 for true).",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=1.0,
+         help="Temperature value for generation.",
+     )
+
+     # Args for accessing detailed info
+     parser.add_argument("--profile", default=False, action="store_true")
+     parser.add_argument(
+         "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
+     )
+     parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
+     parser.add_argument("--verbose", default=False, action="store_true")
+     parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
+     parser.add_argument(
+         "--tune",
+         default=False,
+         action="store_true",
+         help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel",
+     )
+
+     args = parser.parse_args()
+
+     # Set seed properties
+     np.random.seed(args.seed)
+     torch.manual_seed(args.seed)
+
+     args.monitor_type = args.device
+     # Set runtime properties
+     if "ort" in args.benchmark_type:
+         args.execution_provider = f"{args.device.upper()}ExecutionProvider"
+         if args.execution_provider == "CUDAExecutionProvider":
+             args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
+         elif args.execution_provider == "ROCMExecutionProvider":
+             args.execution_provider = (
+                 args.execution_provider,
+                 {
+                     "device_id": args.device_id,
+                     "tunable_op_enable": 1,
+                     "tunable_op_tuning_enable": 1 if args.tune else 0,
+                 },
+             )
+             args.device = "cuda"
+
+     # Check that model paths have been specified for any benchmarking with ORT
+     if args.benchmark_type == "hf-ort":
+         assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
+     if args.benchmark_type == "ort":
+         assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
+
+     # Convert decoder_input_ids string to list of ids
+     # (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT)
+     args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids)
+
+     return args
+
+
+ def main():
+     args = parse_args()
+     setup_logger(args.verbose)
+     logger.info(args.__dict__)
+     torch.backends.cudnn.benchmark = True
+
+     config = WhisperConfig.from_pretrained(args.model_name)
+     processor = WhisperProcessor.from_pretrained(args.model_name)
+     target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
+     use_fp16 = args.precision == "fp16"
+
+     setattr(args, "processor", processor)  # noqa: B010
+     setattr(args, "target_device", target_device)  # noqa: B010
+     setattr(args, "use_fp16", use_fp16)  # noqa: B010
+     setattr(args, "has_audio_stream", False)  # noqa: B010
+     setattr(args, "eos_token_id", config.eos_token_id)  # noqa: B010
+
+     logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")
+
+     # Measure cost to transcribe audio
+     model = get_model(args)
+     if args.benchmark_type == "ort":
+         # Check for optional inputs that could have been added during export
+         ort_model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
+         args.has_audio_stream = "audio_stream" in ort_model_inputs
+         setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs)  # noqa: B010
+         setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs)  # noqa: B010
+         setattr(args, "has_temperature", "temperature" in ort_model_inputs)  # noqa: B010
+
+     if args.decoder_input_ids == []:
+         args.decoder_input_ids = [config.decoder_start_token_id]
+
+     inputs = get_inputs(args)
+     run_inference(args, inputs, model)
+
+
+ if __name__ == "__main__":
+     main()
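
For reference, the "ort" path in this script reduces to a plain onnxruntime session with the onnxruntime-extensions custom-ops library registered. Below is a minimal sketch of that path under stated assumptions: whisper_beamsearch.onnx is a hypothetical convert_to_onnx-style Whisper export whose input names match those built in get_inputs(), and audio.wav is a hypothetical 16 kHz audio file.

# Minimal sketch of the "ort" benchmark path above (paths and model name are
# illustrative assumptions, not part of the packaged script).
import numpy as np
import onnxruntime as ort
import whisper
from onnxruntime_extensions import get_library_path
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

# Load and pad/trim the audio, then extract log-mel features,
# as load_via_ffmpeg() plus the feature_extractor do in get_inputs().
audio = whisper.pad_or_trim(whisper.load_audio("audio.wav"))
input_features = processor.feature_extractor(
    [audio], return_tensors="np", sampling_rate=16000
).input_features

# Register the onnxruntime-extensions custom ops, as get_model() does.
sess_options = ort.SessionOptions()
sess_options.register_custom_ops_library(get_library_path())
session = ort.InferenceSession("whisper_beamsearch.onnx", sess_options, providers=["CPUExecutionProvider"])

# Beam-search parameters are passed as 1-element tensors, mirroring get_inputs().
inputs = {
    "input_features": input_features,
    "max_length": np.array([448], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([1], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
}
generated_ids = session.run(None, inputs)[0]  # (batch, num_return_sequences, seq_len)
print(processor.batch_decode(generated_ids[0], skip_special_tokens=True)[0])

Invoked as a script with the flags defined in parse_args(), the closest equivalent would be: python benchmark.py -bt ort -m openai/whisper-tiny -p fp32 --ort-model-path whisper_beamsearch.onnx -a audio.wav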