onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/whisper/benchmark.py
@@ -0,0 +1,585 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import argparse
import ast
import datetime
import gc
import logging
import os
import sys
import time

import numpy as np
import psutil
import torch
import whisper
from benchmark_helper import measure_memory, setup_logger
from onnxruntime_extensions import get_library_path
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor

import onnxruntime as ort

logger = logging.getLogger(__name__)


def get_inputs(args: argparse.Namespace):
    if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}:
        raise Exception("Unable to auto-detect inputs for provided model")

    def load_via_ffmpeg():
        audio = whisper.load_audio(args.audio_path)
        audio = whisper.pad_or_trim(audio)
        return audio

    def load_via_numpy():
        with open(args.audio_path, "rb") as f:
            audio = np.asarray(list(f.read()), dtype=np.uint8)
            audio = np.array([audio])
        return audio

    inputs = {
        "max_length": args.max_length,
        "min_length": args.min_length,
        "num_beams": args.num_beams,
        "num_return_sequences": args.num_return_sequences,
        "length_penalty": args.length_penalty,
        "repetition_penalty": args.repetition_penalty,
    }
    if args.benchmark_type == "ort":
        # convert_to_onnx export or ONNX E2E solution created by Olive
        for k, v in inputs.items():
            inputs[k] = np.array([v], dtype=np.float32 if "penalty" in k else np.int32)
        if args.has_decoder_input_ids:
            inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
        if args.has_logits_processor:
            inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)
        if args.has_temperature:
            inputs["temperature"] = np.array([args.temperature], dtype=np.float32)

    # Measure time taken to load audio file
    logger.info(f"Load audio: {args.audio_path}")
    load_audio_fn = lambda onnx_e2e: load_via_numpy() if onnx_e2e else load_via_ffmpeg()  # noqa: E731
    time_fn(args, load_audio_fn, args.has_audio_stream)
    audio_data = load_audio_fn(args.has_audio_stream)

    if args.has_audio_stream:
        # ONNX E2E solution created by Olive
        inputs["audio_stream"] = audio_data
        return inputs

    # Measure time taken to get input features
    logger.info("Feature extraction: ")
    return_type = "np" if args.benchmark_type == "ort" else "pt"
    processor_fn = lambda audio: args.processor.feature_extractor(  # noqa: E731
        [audio], return_tensors=return_type, sampling_rate=args.sampling_rate
    ).input_features
    time_fn(args, processor_fn, audio_data)
    input_features = processor_fn(audio_data)

    if args.benchmark_type == "ort":
        # convert_to_onnx export
        inputs["input_features"] = input_features
        return inputs

    inputs["inputs"] = input_features.to(
        dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device
    )
    inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size
    inputs["early_stopping"] = True
    inputs["use_cache"] = True

    if args.decoder_input_ids:
        inputs["forced_decoder_ids"] = args.decoder_input_ids

    return inputs


def get_model(args: argparse.Namespace):
    model, sess_options = None, None
    start_time, end_time = None, None

    # There are multiple sources that the model could come from:
    # 1) Benchmark Whisper from Hugging Face
    # 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing)
    # 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing)

    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
        source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
        start_time = time.time()
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            source,
            torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
            use_cache=True,
        ).to(args.target_device)
        end_time = time.time()

        if args.benchmark_type == "hf-pt-compile":
            model = torch.compile(model)

    elif args.benchmark_type in {"hf-ort", "ort"}:
        sess_options = ort.SessionOptions()
        sess_options.enable_profiling = args.profile
        sess_options.register_custom_ops_library(get_library_path())
        if args.verbose:
            sess_options.log_verbosity_level = 1
            sess_options.log_severity_level = 1

    else:
        raise Exception(f"Cannot recognize {args.benchmark_type}")

    if args.benchmark_type == "hf-ort":
        # Optimum export
        provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
        provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None

        start_time = time.time()
        model = ORTModelForSpeechSeq2Seq.from_pretrained(
            args.hf_ort_dir_path,
            provider=provider,
            provider_options=provider_options,
            session_options=sess_options,
            use_io_binding=True,  # Avoid memory copy overhead
        )
        end_time = time.time()

    if args.benchmark_type == "ort":
        # convert_to_onnx.py export
        logger.info(f"Loading model from {args.ort_model_path}")
        start_time = time.time()
        model = ort.InferenceSession(
            args.ort_model_path,
            sess_options,
            providers=[args.execution_provider],
        )
        end_time = time.time()

    logger.info(f"Loaded model in {end_time - start_time} s")

    return model


def time_fn(args, fn, inputs):
    warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
    benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
    torch_device = torch.device(args.target_device)

    # Warm up
    warmup_range = (
        range(args.warmup_runs)
        if args.benchmark_type == "ort"
        else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
    )

    if args.verbose:
        outputs = fn(warmup_inputs)
        logger.info(outputs)

    for _ in warmup_range:
        fn(warmup_inputs)

    # Benchmark
    if args.device != "cpu":
        torch.cuda.synchronize(torch_device)
    start_time = time.time()

    bench_range = (
        range(args.num_runs)
        if args.benchmark_type == "ort"
        else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
    )
    for _ in bench_range:
        fn(benchmark_inputs)

    if args.device != "cpu":
        torch.cuda.synchronize(torch_device)
    end_time = time.time()

    # Newline print after trange in order to print metrics on new lines without progress bar on same line
    if args.benchmark_type != "ort":
        logger.info("")

    batch_size = 1
    latency = (end_time - start_time) / args.num_runs
    throughput = batch_size / latency

    logger.info(f"Latency: {latency} s")
    logger.info(f"Throughput: {throughput} qps")
    return


def profile_fn(args, fn, inputs, inputs_type):
    # Filename prefix format:
    # "<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
    prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
    filename = None

    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
        # Profile PyTorch kernels
        with profile(  # noqa: SIM117
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
        ) as prof:
            with record_function("model_inference"):
                fn(inputs)
        prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)

        filename = os.path.join(args.log_folder, f"{prefix}.log")
        with open(filename, "w") as f:
            f.write(prof_data)

    else:
        # Profile ORT kernels
        fn(inputs)

        # Set new log name for ORT profile log generated
        filename = f"{prefix}.json"

    return filename


def measure_fn(args, fn, inputs):
    # Measure CPU usage
    pid = os.getpid()
    process = psutil.Process(pid)
    process.cpu_percent(interval=0.1)

    fn(inputs)
    logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")

    # Measure memory usage
    gc.collect()
    torch.cuda.empty_cache()
    measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type)

    # Flush output so memory usage is printed
    sys.stdout.flush()


def run_hf_inference(args, inputs, model):
    # Inference steps to measure
    def get_pred_ids(inputs):
        # Inference pass with predicted token ids generation
        predicted_ids = model.generate(**inputs)
        return predicted_ids

    def gen_and_dec(inputs):
        # Inference pass with generation and decoding
        predicted_ids = get_pred_ids(inputs)
        transcription = []
        for _ in range(args.num_return_sequences):
            transcription.append(args.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
        return predicted_ids, transcription

    # Examples of other inference steps that can be measured:
    # To use, uncomment the function and assign it to `generate_fn`

    # def get_logits(inputs):
    #     # Inference pass without decoding
    #     outputs = model(**inputs)
    #     return outputs

    generate_fn = gen_and_dec

    if args.benchmark_type == "hf-pt-compile":
        # Run forward pass once with each set of inputs to process through Dynamo
        generate_fn(inputs)

    if args.profile:
        new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec")
        if args.benchmark_type == "hf-ort":
            # Rename log files per model component and turn profiling off to stop appending to log
            new_prefix = new_logname[: -len(".json")]

            old_logname = model.encoder.session.end_profiling()
            new_logname = new_prefix + "-encoder.json"
            if os.path.isfile(old_logname):
                logger.warning(f"Renaming {old_logname} to {new_logname}")
                os.rename(old_logname, os.path.join(args.log_folder, new_logname))

            old_logname = model.decoder.session.end_profiling()
            new_logname = new_prefix + "-decoder.json"
            if os.path.isfile(old_logname):
                logger.warning(f"Renaming {old_logname} to {new_logname}")
                os.rename(old_logname, os.path.join(args.log_folder, new_logname))

            old_logname = model.decoder_with_past.session.end_profiling()
            new_logname = new_prefix + "-decoder-with-past.json"
            if os.path.isfile(old_logname):
                logger.warning(f"Renaming {old_logname} to {new_logname}")
                os.rename(old_logname, os.path.join(args.log_folder, new_logname))

        return

    # PyTorch evaluations
    logger.info("\nEvaluating PyTorch...")
    time_fn(args, generate_fn, inputs)
    predicted_ids, transcription = generate_fn(inputs)
    logger.info(f"Generated token length: {len(predicted_ids[0])} tokens")
    logger.info(f"Transcription: {transcription[0]}")
    measure_fn(args, generate_fn, inputs)


def run_ort_inference(args, inputs, model):
    def prepare_ort_inputs(inputs, warmup=False):
        # Check that all model inputs will be provided
        model_inputs = {model_input.name for model_input in model.get_inputs()}
        user_inputs = set(inputs.keys())
        missing_inputs = model_inputs - user_inputs
        if len(missing_inputs):
            logger.error(f"The following model inputs are missing: {missing_inputs}")
            raise Exception("There are missing inputs to the model. Please add them and try again.")

        # Remove unnecessary inputs from model inputs
        unnecessary_inputs = user_inputs - model_inputs
        if len(unnecessary_inputs):
            for unnecessary_input in unnecessary_inputs:
                logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
                del inputs[unnecessary_input]

        # Add IO bindings for non-CPU execution providers
        if args.device != "cpu":
            io_binding = model.io_binding()
            for k, v in inputs.items():
                io_binding.bind_cpu_input(k, v)
            for output in model.get_outputs():
                io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id)
            return io_binding

        return inputs

    def with_io_binding(io_binding):
        # Inference pass with IO binding
        model.run_with_iobinding(io_binding)
        return io_binding

    def without_io_binding(inputs):
        # Inference pass without IO binding
        outputs = model.run(None, inputs)
        return outputs

    def handle_output(output):
        if args.eos_token_id in output:
            first_end = np.where(output == args.eos_token_id)[0][0]
            return output[: first_end + 1]

        return output

    generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
    ort_inputs = prepare_ort_inputs(inputs)

    if args.profile:
        new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e")

        # Turn profiling off to stop appending to log file
        old_logname = model.end_profiling()
        logger.warning(f"Renaming {old_logname} to {new_logname}")
        os.rename(old_logname, os.path.join(args.log_folder, new_logname))

        return

    # ORT evaluation
    logger.info("\nEvaluating ONNX Runtime...")
    ort_evaluate_inputs = ort_inputs

    time_fn(args, generate_fn, ort_evaluate_inputs)
    ort_outputs = generate_fn(ort_inputs)
    if args.device != "cpu":
        ort_outputs = ort_outputs.copy_outputs_to_cpu()
    ort_outputs = ort_outputs[0]

    if args.has_audio_stream:
        # ONNX E2E model from Olive produces transcribed output
        logger.info(f"Transcription: {ort_outputs[0][0]}")
    else:
        # convert_to_onnx model produces generated ids
        actual_output = handle_output(ort_outputs[0][0])
        logger.info(f"Generated token length: {len(actual_output)} tokens")
        transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
        # print to stdout as the output for comparison
        print(f"{transcription}")

    measure_fn(args, generate_fn, ort_inputs)


def run_inference(args, inputs, model):
    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
        run_hf_inference(args, inputs, model)
    elif args.benchmark_type == "ort":
        run_ort_inference(args, inputs, model)
    else:
        raise Exception(f"Cannot recognize {args.benchmark_type}")


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-bt",
        "--benchmark-type",
        type=str,
        required=True,
        choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"],
    )

    parser.add_argument(
        "-m",
        "--model-name",
        type=str,
        required=True,
        help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')",
    )
    parser.add_argument(
        "-p",
        "--precision",
        type=str,
        required=True,
        default="fp32",
        choices=["int4", "int8", "fp16", "fp32"],
        help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
    )

    parser.add_argument(
        "--hf-pt-model-path",
        type=str,
        default="",
        help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
    )
    parser.add_argument(
        "--hf-ort-dir-path",
        type=str,
        default="",
        help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
    )
    parser.add_argument(
        "--ort-model-path",
        type=str,
        default="",
        help="Path to ONNX model",
    )

    # Args for running and evaluating the model
    parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation")
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cpu", "cuda"],
    )
    parser.add_argument("-id", "--device-id", type=int, default=0)
    parser.add_argument("-w", "--warmup-runs", type=int, default=5)
    parser.add_argument("-n", "--num-runs", type=int, default=10)
    parser.add_argument("--seed", type=int, default=2)

    # Optional args:
    parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)")

    # Args for decoding logic
    # Required args:
    parser.add_argument("--max-length", type=int, default=448)
    parser.add_argument("--min-length", type=int, default=0)
    parser.add_argument("--num-beams", type=int, default=1)
    parser.add_argument("--num-return-sequences", type=int, default=1)
    parser.add_argument("--length-penalty", type=float, default=1.0)
    parser.add_argument("--repetition-penalty", type=float, default=1.0)
    parser.add_argument("--no-repeat-ngram-size", type=int, default=3)

    # Optional args for E2E solution:
    parser.add_argument(
        "--decoder-input-ids",
        type=str,
        default="[]",
        help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.",
    )
    parser.add_argument(
        "--logits-processor",
        type=int,
        default=1,
        help="Whether to use timestamps logits processor or not (0 for false, 1 for true).",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Temperature value for generation.",
    )

    # Args for accessing detailed info
    parser.add_argument("--profile", default=False, action="store_true")
    parser.add_argument(
        "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
    )
    parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
    parser.add_argument("--verbose", default=False, action="store_true")
    parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")

    args = parser.parse_args()

    # Set seed properties
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    args.monitor_type = args.device
    # Set runtime properties
    if "ort" in args.benchmark_type:
        args.execution_provider = f"{args.device.upper()}ExecutionProvider"
        if args.execution_provider == "CUDAExecutionProvider":
            args.execution_provider = (args.execution_provider, {"device_id": args.device_id})

    # Check that model paths have been specified for any benchmarking with ORT
    if args.benchmark_type == "hf-ort":
        assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
    if args.benchmark_type == "ort":
        assert args.ort_model_path, "Please specify a path to `--ort-model-path`"

    # Convert decoder_input_ids string to list of ids
    # (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT)
    args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids)

    return args


def main():
    args = parse_args()
    setup_logger(args.verbose)
    logger.info(args.__dict__)
    torch.backends.cudnn.benchmark = True

    config = WhisperConfig.from_pretrained(args.model_name)
    processor = WhisperProcessor.from_pretrained(args.model_name)
    target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
    use_fp16 = args.precision == "fp16" or (args.precision in {"int8", "int4"} and args.device != "cpu")

    setattr(args, "processor", processor)  # noqa: B010
    setattr(args, "target_device", target_device)  # noqa: B010
    setattr(args, "use_fp16", use_fp16)  # noqa: B010
    setattr(args, "has_audio_stream", False)  # noqa: B010
    setattr(args, "eos_token_id", config.eos_token_id)  # noqa: B010

    logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")

    # Measure cost to transcribe audio
    model = get_model(args)
    if args.benchmark_type == "ort":
        # Check for optional inputs that could have been added during export
        ort_model_inputs = {model_input.name for model_input in model.get_inputs()}
        args.has_audio_stream = "audio_stream" in ort_model_inputs
        setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs)  # noqa: B010
        setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs)  # noqa: B010
        setattr(args, "has_temperature", "temperature" in ort_model_inputs)  # noqa: B010

    if args.decoder_input_ids == []:
        args.decoder_input_ids = [config.decoder_start_token_id]

    inputs = get_inputs(args)
    run_inference(args, inputs, model)


if __name__ == "__main__":
    main()