onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
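
Note: the wheel registers the DirectML execution provider ("DmlExecutionProvider") alongside the default CPU provider. As a minimal usage sketch, not taken from the package itself (the model path and input name below are hypothetical placeholders):

import numpy as np
import onnxruntime as ort

# Prefer DirectML, falling back to CPU if the device is unavailable.
session = ort.InferenceSession(
    "model.onnx",  # hypothetical model path
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)
outputs = session.run(None, {"input": np.zeros((1, 3), dtype=np.float32)})  # hypothetical input name/shape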
onnxruntime/transformers/models/whisper/benchmark.py
@@ -0,0 +1,585 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import argparse
+ import ast
+ import datetime
+ import gc
+ import logging
+ import os
+ import sys
+ import time
+
+ import numpy as np
+ import psutil
+ import torch
+ import whisper
+ from benchmark_helper import measure_memory, setup_logger
+ from onnxruntime_extensions import get_library_path
+ from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
+ from torch.profiler import ProfilerActivity, profile, record_function
+ from tqdm import trange
+ from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor
+
+ import onnxruntime as ort
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_inputs(args: argparse.Namespace):
+     if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}:
+         raise Exception("Unable to auto-detect inputs for provided model")
+
+     def load_via_ffmpeg():
+         audio = whisper.load_audio(args.audio_path)
+         audio = whisper.pad_or_trim(audio)
+         return audio
+
+     def load_via_numpy():
+         with open(args.audio_path, "rb") as f:
+             audio = np.asarray(list(f.read()), dtype=np.uint8)
+         audio = np.array([audio])
+         return audio
+
+     inputs = {
+         "max_length": args.max_length,
+         "min_length": args.min_length,
+         "num_beams": args.num_beams,
+         "num_return_sequences": args.num_return_sequences,
+         "length_penalty": args.length_penalty,
+         "repetition_penalty": args.repetition_penalty,
+     }
+     if args.benchmark_type == "ort":
+         # convert_to_onnx export or ONNX E2E solution created by Olive
+         for k, v in inputs.items():
+             inputs[k] = np.array([v], dtype=np.float32 if "penalty" in k else np.int32)
+         if args.has_decoder_input_ids:
+             inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
+         if args.has_logits_processor:
+             inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)
+         if args.has_temperature:
+             inputs["temperature"] = np.array([args.temperature], dtype=np.float32)
+
+     # Measure time taken to load audio file
+     logger.info(f"Load audio: {args.audio_path}")
+     load_audio_fn = lambda onnx_e2e: load_via_numpy() if onnx_e2e else load_via_ffmpeg()  # noqa: E731
+     time_fn(args, load_audio_fn, args.has_audio_stream)
+     audio_data = load_audio_fn(args.has_audio_stream)
+
+     if args.has_audio_stream:
+         # ONNX E2E solution created by Olive
+         inputs["audio_stream"] = audio_data
+         return inputs
+
+     # Measure time taken to get input features
+     logger.info("Feature extraction: ")
+     return_type = "np" if args.benchmark_type == "ort" else "pt"
+     processor_fn = lambda audio: args.processor.feature_extractor(  # noqa: E731
+         [audio], return_tensors=return_type, sampling_rate=args.sampling_rate
+     ).input_features
+     time_fn(args, processor_fn, audio_data)
+     input_features = processor_fn(audio_data)
+
+     if args.benchmark_type == "ort":
+         # convert_to_onnx export
+         inputs["input_features"] = input_features
+         return inputs
+
+     inputs["inputs"] = input_features.to(
+         dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device
+     )
+     inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size
+     inputs["early_stopping"] = True
+     inputs["use_cache"] = True
+
+     if args.decoder_input_ids:
+         inputs["forced_decoder_ids"] = args.decoder_input_ids
+
+     return inputs
+
+
+ def get_model(args: argparse.Namespace):
+     model, sess_options = None, None
+     start_time, end_time = None, None
+
+     # There are multiple sources that the model could come from:
+     # 1) Benchmark Whisper from Hugging Face
+     # 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing)
+     # 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing)
+
+     if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+         source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
+         start_time = time.time()
+         model = AutoModelForSpeechSeq2Seq.from_pretrained(
+             source,
+             torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
+             use_cache=True,
+         ).to(args.target_device)
+         end_time = time.time()
+
+         if args.benchmark_type == "hf-pt-compile":
+             model = torch.compile(model)
+
+     elif args.benchmark_type in {"hf-ort", "ort"}:
+         sess_options = ort.SessionOptions()
+         sess_options.enable_profiling = args.profile
+         sess_options.register_custom_ops_library(get_library_path())
+         if args.verbose:
+             sess_options.log_verbosity_level = 1
+             sess_options.log_severity_level = 1
+
+     else:
+         raise Exception(f"Cannot recognize {args.benchmark_type}")
+
+     if args.benchmark_type == "hf-ort":
+         # Optimum export
+         provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
+         provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
+
+         start_time = time.time()
+         model = ORTModelForSpeechSeq2Seq.from_pretrained(
+             args.hf_ort_dir_path,
+             provider=provider,
+             provider_options=provider_options,
+             session_options=sess_options,
+             use_io_binding=True,  # Avoid memory copy overhead
+         )
+         end_time = time.time()
+
+     if args.benchmark_type == "ort":
+         # convert_to_onnx.py export
+         logger.info(f"Loading model from {args.ort_model_path}")
+         start_time = time.time()
+         model = ort.InferenceSession(
+             args.ort_model_path,
+             sess_options,
+             providers=[args.execution_provider],
+         )
+         end_time = time.time()
+
+     logger.info(f"Loaded model in {end_time - start_time} s")
+
+     return model
+
+
+ def time_fn(args, fn, inputs):
+     warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
+     benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
+     torch_device = torch.device(args.target_device)
+
+     # Warm up
+     warmup_range = (
+         range(args.warmup_runs)
+         if args.benchmark_type == "ort"
+         else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
+     )
+
+     if args.verbose:
+         outputs = fn(warmup_inputs)
+         logger.info(outputs)
+
+     for _ in warmup_range:
+         fn(warmup_inputs)
+
+     # Benchmark
+     if args.device != "cpu":
+         torch.cuda.synchronize(torch_device)
+     start_time = time.time()
+
+     bench_range = (
+         range(args.num_runs)
+         if args.benchmark_type == "ort"
+         else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
+     )
+     for _ in bench_range:
+         fn(benchmark_inputs)
+
+     if args.device != "cpu":
+         torch.cuda.synchronize(torch_device)
+     end_time = time.time()
+
+     # Newline print after trange in order to print metrics on new lines without progress bar on same line
+     if args.benchmark_type != "ort":
+         logger.info("")
+
+     batch_size = 1
+     latency = (end_time - start_time) / args.num_runs
+     throughput = batch_size / latency
+
+     logger.info(f"Latency: {latency} s")
+     logger.info(f"Throughput: {throughput} qps")
+     return
+
+
+ def profile_fn(args, fn, inputs, inputs_type):
+     # Filename prefix format:
+     # "<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
+     prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
+     filename = None
+
+     if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+         # Profile PyTorch kernels
+         with profile(  # noqa: SIM117
+             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
+         ) as prof:
+             with record_function("model_inference"):
+                 fn(inputs)
+         prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
+
+         filename = os.path.join(args.log_folder, f"{prefix}.log")
+         with open(filename, "w") as f:
+             f.write(prof_data)
+
+     else:
+         # Profile ORT kernels
+         fn(inputs)
+
+         # Set new log name for ORT profile log generated
+         filename = f"{prefix}.json"
+
+     return filename
+
+
+ def measure_fn(args, fn, inputs):
+     # Measure CPU usage
+     pid = os.getpid()
+     process = psutil.Process(pid)
+     process.cpu_percent(interval=0.1)
+
+     fn(inputs)
+     logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")
+
+     # Measure memory usage
+     gc.collect()
+     torch.cuda.empty_cache()
+     measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type)
+
+     # Flush output so memory usage is printed
+     sys.stdout.flush()
+
+
+ def run_hf_inference(args, inputs, model):
+     # Inference steps to measure
+     def get_pred_ids(inputs):
+         # Inference pass with predicted token ids generation
+         predicted_ids = model.generate(**inputs)
+         return predicted_ids
+
+     def gen_and_dec(inputs):
+         # Inference pass with generation and decoding
+         predicted_ids = get_pred_ids(inputs)
+         transcription = []
+         for _ in range(args.num_return_sequences):
+             transcription.append(args.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
+         return predicted_ids, transcription
+
+     # Examples of other inference steps that can be measured:
+     # To use, uncomment the function and assign it to `generate_fn`
+
+     # def get_logits(inputs):
+     #     # Inference pass without decoding
+     #     outputs = model(**inputs)
+     #     return outputs
+
+     generate_fn = gen_and_dec
+
+     if args.benchmark_type == "hf-pt-compile":
+         # Run forward pass once with each set of inputs to process through Dynamo
+         generate_fn(inputs)
+
+     if args.profile:
+         new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec")
+         if args.benchmark_type == "hf-ort":
+             # Rename log files per model component and turn profiling off to stop appending to log
+             new_prefix = new_logname[: -len(".json")]
+
+             old_logname = model.encoder.session.end_profiling()
+             new_logname = new_prefix + "-encoder.json"
+             if os.path.isfile(old_logname):
+                 logger.warning(f"Renaming {old_logname} to {new_logname}")
+                 os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+             old_logname = model.decoder.session.end_profiling()
+             new_logname = new_prefix + "-decoder.json"
+             if os.path.isfile(old_logname):
+                 logger.warning(f"Renaming {old_logname} to {new_logname}")
+                 os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+             old_logname = model.decoder_with_past.session.end_profiling()
+             new_logname = new_prefix + "-decoder-with-past.json"
+             if os.path.isfile(old_logname):
+                 logger.warning(f"Renaming {old_logname} to {new_logname}")
+                 os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+         return
+
+     # PyTorch evaluations
+     logger.info("\nEvaluating PyTorch...")
+     time_fn(args, generate_fn, inputs)
+     predicted_ids, transcription = generate_fn(inputs)
+     logger.info(f"Generated token length: {len(predicted_ids[0])} tokens")
+     logger.info(f"Transcription: {transcription[0]}")
+     measure_fn(args, generate_fn, inputs)
+
+
+ def run_ort_inference(args, inputs, model):
+     def prepare_ort_inputs(inputs, warmup=False):
+         # Check that all model inputs will be provided
+         model_inputs = {model_input.name for model_input in model.get_inputs()}
+         user_inputs = set(inputs.keys())
+         missing_inputs = model_inputs - user_inputs
+         if len(missing_inputs):
+             logger.error(f"The following model inputs are missing: {missing_inputs}")
+             raise Exception("There are missing inputs to the model. Please add them and try again.")
+
+         # Remove unnecessary inputs from model inputs
+         unnecessary_inputs = user_inputs - model_inputs
+         if len(unnecessary_inputs):
+             for unnecessary_input in unnecessary_inputs:
+                 logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
+                 del inputs[unnecessary_input]
+
+         # Add IO bindings for non-CPU execution providers
+         if args.device != "cpu":
+             io_binding = model.io_binding()
+             for k, v in inputs.items():
+                 io_binding.bind_cpu_input(k, v)
+             for output in model.get_outputs():
+                 io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id)
+             return io_binding
+
+         return inputs
+
+     def with_io_binding(io_binding):
+         # Inference pass with IO binding
+         model.run_with_iobinding(io_binding)
+         return io_binding
+
+     def without_io_binding(inputs):
+         # Inference pass without IO binding
+         outputs = model.run(None, inputs)
+         return outputs
+
+     def handle_output(output):
+         if args.eos_token_id in output:
+             first_end = np.where(output == args.eos_token_id)[0][0]
+             return output[: first_end + 1]
+
+         return output
+
+     generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
+     ort_inputs = prepare_ort_inputs(inputs)
+
+     if args.profile:
+         new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e")
+
+         # Turn profiling off to stop appending to log file
+         old_logname = model.end_profiling()
+         logger.warning(f"Renaming {old_logname} to {new_logname}")
+         os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+         return
+
+     # ORT evaluation
+     logger.info("\nEvaluating ONNX Runtime...")
+     ort_evaluate_inputs = ort_inputs
+
+     time_fn(args, generate_fn, ort_evaluate_inputs)
+     ort_outputs = generate_fn(ort_inputs)
+     if args.device != "cpu":
+         ort_outputs = ort_outputs.copy_outputs_to_cpu()
+     ort_outputs = ort_outputs[0]
+
+     if args.has_audio_stream:
+         # ONNX E2E model from Olive produces transcribed output
+         logger.info(f"Transcription: {ort_outputs[0][0]}")
+     else:
+         # convert_to_onnx model produces generated ids
+         actual_output = handle_output(ort_outputs[0][0])
+         logger.info(f"Generated token length: {len(actual_output)} tokens")
+         transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
+         # print to stdout as the output for comparison
+         print(f"{transcription}")
+
+     measure_fn(args, generate_fn, ort_inputs)
+
+
+ def run_inference(args, inputs, model):
+     if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
+         run_hf_inference(args, inputs, model)
+     elif args.benchmark_type == "ort":
+         run_ort_inference(args, inputs, model)
+     else:
+         raise Exception(f"Cannot recognize {args.benchmark_type}")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-bt",
+         "--benchmark-type",
+         type=str,
+         required=True,
+         choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"],
+     )
+
+     parser.add_argument(
+         "-m",
+         "--model-name",
+         type=str,
+         required=True,
+         help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')",
+     )
+     parser.add_argument(
+         "-p",
+         "--precision",
+         type=str,
+         required=True,
+         default="fp32",
+         choices=["int4", "int8", "fp16", "fp32"],
+         help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
+     )
+
+     parser.add_argument(
+         "--hf-pt-model-path",
+         type=str,
+         default="",
+         help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
+     )
+     parser.add_argument(
+         "--hf-ort-dir-path",
+         type=str,
+         default="",
+         help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
+     )
+     parser.add_argument(
+         "--ort-model-path",
+         type=str,
+         default="",
+         help="Path to ONNX model",
+     )
+
+     # Args for running and evaluating the model
+     parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation")
+     parser.add_argument(
+         "-d",
+         "--device",
+         type=str,
+         default="cuda" if torch.cuda.is_available() else "cpu",
+         choices=["cpu", "cuda"],
+     )
+     parser.add_argument("-id", "--device-id", type=int, default=0)
+     parser.add_argument("-w", "--warmup-runs", type=int, default=5)
+     parser.add_argument("-n", "--num-runs", type=int, default=10)
+     parser.add_argument("--seed", type=int, default=2)
+
+     # Optional args:
+     parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)")
+
+     # Args for decoding logic
+     # Required args:
+     parser.add_argument("--max-length", type=int, default=448)
+     parser.add_argument("--min-length", type=int, default=0)
+     parser.add_argument("--num-beams", type=int, default=1)
+     parser.add_argument("--num-return-sequences", type=int, default=1)
+     parser.add_argument("--length-penalty", type=float, default=1.0)
+     parser.add_argument("--repetition-penalty", type=float, default=1.0)
+     parser.add_argument("--no-repeat-ngram-size", type=int, default=3)
+
+     # Optional args for E2E solution:
+     parser.add_argument(
+         "--decoder-input-ids",
+         type=str,
+         default="[]",
+         help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.",
+     )
+     parser.add_argument(
+         "--logits-processor",
+         type=int,
+         default=1,
+         help="Whether to use timestamps logits processor or not (0 for false, 1 for true).",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=1.0,
+         help="Temperature value for generation.",
+     )
+
+     # Args for accessing detailed info
+     parser.add_argument("--profile", default=False, action="store_true")
+     parser.add_argument(
+         "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
+     )
+     parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
+     parser.add_argument("--verbose", default=False, action="store_true")
+     parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
+
+     args = parser.parse_args()
+
+     # Set seed properties
+     np.random.seed(args.seed)
+     torch.manual_seed(args.seed)
+
+     args.monitor_type = args.device
+     # Set runtime properties
+     if "ort" in args.benchmark_type:
+         args.execution_provider = f"{args.device.upper()}ExecutionProvider"
+         if args.execution_provider == "CUDAExecutionProvider":
+             args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
+
+     # Check that model paths have been specified for any benchmarking with ORT
+     if args.benchmark_type == "hf-ort":
+         assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
+     if args.benchmark_type == "ort":
+         assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
+
+     # Convert decoder_input_ids string to list of ids
+     # (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT)
+     args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids)
+
+     return args
+
+
+ def main():
+     args = parse_args()
+     setup_logger(args.verbose)
+     logger.info(args.__dict__)
+     torch.backends.cudnn.benchmark = True
+
+     config = WhisperConfig.from_pretrained(args.model_name)
+     processor = WhisperProcessor.from_pretrained(args.model_name)
+     target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
+     use_fp16 = args.precision == "fp16" or (args.precision in {"int8", "int4"} and args.device != "cpu")
+
+     setattr(args, "processor", processor)  # noqa: B010
+     setattr(args, "target_device", target_device)  # noqa: B010
+     setattr(args, "use_fp16", use_fp16)  # noqa: B010
+     setattr(args, "has_audio_stream", False)  # noqa: B010
+     setattr(args, "eos_token_id", config.eos_token_id)  # noqa: B010
+
+     logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")
+
+     # Measure cost to transcribe audio
+     model = get_model(args)
+     if args.benchmark_type == "ort":
+         # Check for optional inputs that could have been added during export
+         ort_model_inputs = {model_input.name for model_input in model.get_inputs()}
+         args.has_audio_stream = "audio_stream" in ort_model_inputs
+         setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs)  # noqa: B010
+         setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs)  # noqa: B010
+         setattr(args, "has_temperature", "temperature" in ort_model_inputs)  # noqa: B010
+
+     if args.decoder_input_ids == []:
+         args.decoder_input_ids = [config.decoder_start_token_id]
+
+     inputs = get_inputs(args)
+     run_inference(args, inputs, model)
+
+
+ if __name__ == "__main__":
+     main()
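
The run_ort_inference path above binds host-side inputs through ONNX Runtime's IO binding so that, on a non-CPU execution provider, input copies happen once per run and outputs stay device-managed until read back. A self-contained sketch of that same pattern outside the benchmark harness (the model path and input name/shape below are hypothetical placeholders):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["DmlExecutionProvider"])  # hypothetical path

# Bind inputs from host memory and let ONNX Runtime allocate the outputs;
# this mirrors prepare_ort_inputs/with_io_binding in the diffed script.
io_binding = session.io_binding()
io_binding.bind_cpu_input("input_features", np.zeros((1, 80, 3000), dtype=np.float32))  # hypothetical name/shape
for output in session.get_outputs():
    io_binding.bind_output(output.name)

session.run_with_iobinding(io_binding)
results = io_binding.copy_outputs_to_cpu()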