onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import copy
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
|
|
14
|
+
from whisper_chain import chain_model
|
|
15
|
+
from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper
|
|
16
|
+
|
|
17
|
+
from onnxruntime import quantization
|
|
18
|
+
|
|
19
|
+
# Root logger (empty name) used for this script's log messages.
logger = logging.getLogger("")

# Short provider names accepted by --provider, mapped to the ONNX Runtime
# execution provider identifiers. Insertion order (cpu, cuda, rocm) is also
# the order in which the choices are offered on the command line.
PROVIDERS = dict(
    cpu="CPUExecutionProvider",
    cuda="CUDAExecutionProvider",
    rocm="ROCMExecutionProvider",
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def parse_arguments(argv=None):
    """Build the Whisper ONNX conversion command-line parser and parse ``argv``.

    Arguments are organized into four groups: conversion process options,
    optional inputs and optional outputs of the WhisperBeamSearch op, and
    INT8 quantization options.

    Args:
        argv: List of argument strings to parse. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed arguments. ``collect_cross_qk`` is forced
        to True whenever ``output_cross_qk`` is requested, since outputting the
        cross QK implies collecting it.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser()

    conversion_args = parser.add_argument_group("Conversion Process Args")
    optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)")
    optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)")
    quant_args = parser.add_argument_group("INT8 Quantization Args")

    #################################
    # Conversion options for Whisper
    #################################

    conversion_args.add_argument(
        "-m",
        "--model_name_or_path",
        required=False,
        default=PRETRAINED_WHISPER_MODELS[0],
        type=str,
        help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS),
    )

    conversion_args.add_argument(
        "--model_impl",
        required=False,
        default="hf",
        choices=["hf", "openai"],
        type=str,
        help="Select implementation for export of encoder and decoder subgraphs",
    )

    conversion_args.add_argument(
        "--cache_dir",
        required=False,
        type=str,
        default=os.path.join(".", "cache_models"),
        help="Directory to cache pre-trained models",
    )

    conversion_args.add_argument(
        "--output",
        required=False,
        type=str,
        default=os.path.join(".", "onnx_models"),
        help="Output directory",
    )

    conversion_args.add_argument(
        "-o",
        "--optimize_onnx",
        required=False,
        action="store_true",
        help="Use optimizer.py to optimize onnx model",
    )
    conversion_args.set_defaults(optimize_onnx=False)

    conversion_args.add_argument(
        "--use_gpu",
        required=False,
        action="store_true",
        help="Use GPU for model inference",
    )
    conversion_args.set_defaults(use_gpu=False)

    conversion_args.add_argument(
        "-p",
        "--precision",
        required=False,
        type=Precision,
        default=Precision.FLOAT32,
        choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
        help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization",
    )

    conversion_args.add_argument(
        "--use_int64_inputs",
        required=False,
        action="store_true",
        help="Use int64 instead of int32 for input_ids and attention_mask.",
    )
    conversion_args.set_defaults(use_int64_inputs=False)

    conversion_args.add_argument(
        "--disable_auto_mixed_precision",
        required=False,
        action="store_true",
        help="Use pure fp16 instead of mixed precision",
    )
    conversion_args.set_defaults(disable_auto_mixed_precision=False)

    conversion_args.add_argument(
        "-r",
        "--provider",
        required=False,
        type=str,
        default="cpu",
        choices=list(PROVIDERS.keys()),
        help="Provider to benchmark. Default is CPUExecutionProvider.",
    )

    conversion_args.add_argument(
        "--verbose",
        required=False,
        action="store_true",
        help="Enable verbose logging",
    )
    conversion_args.set_defaults(verbose=False)

    conversion_args.add_argument(
        "-e",
        "--use_external_data_format",
        required=False,
        action="store_true",
        help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.",
    )
    conversion_args.set_defaults(use_external_data_format=False)

    conversion_args.add_argument(
        "-w",
        "--overwrite",
        required=False,
        action="store_true",
        help="Overwrite existing ONNX model",
    )
    conversion_args.set_defaults(overwrite=False)

    conversion_args.add_argument(
        "--separate_encoder_and_decoder_init",
        required=False,
        action="store_true",
        help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.",
    )
    conversion_args.set_defaults(separate_encoder_and_decoder_init=False)

    conversion_args.add_argument(
        "--no_beam_search_op",
        required=False,
        action="store_true",
        help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.",
    )
    conversion_args.set_defaults(no_beam_search_op=False)

    conversion_args.add_argument(
        "--state_dict_path",
        type=str,
        default="",
        help="Filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
    )

    #############################################################
    # Optional inputs for Whisper
    # (listed below in the order that WhisperBeamSearch expects)
    #############################################################

    optional_inputs.add_argument(
        "-v",
        "--use_vocab_mask",
        required=False,
        action="store_true",
        help="Use vocab_mask as an extra graph input to enable specific logits processing",
    )
    optional_inputs.set_defaults(use_vocab_mask=False)

    optional_inputs.add_argument(
        "-u",
        "--use_prefix_vocab_mask",
        required=False,
        action="store_true",
        help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing",
    )
    optional_inputs.set_defaults(use_prefix_vocab_mask=False)

    optional_inputs.add_argument(
        "-f",
        "--use_forced_decoder_ids",
        required=False,
        action="store_true",
        help="Use decoder_input_ids as an extra graph input to the beam search op",
    )
    optional_inputs.set_defaults(use_forced_decoder_ids=False)

    optional_inputs.add_argument(
        "-l",
        "--use_logits_processor",
        required=False,
        action="store_true",
        help="Use logits_processor as an extra graph input to enable specific logits processing",
    )
    # The destination of "--use_logits_processor" is "use_logits_processor".
    # Earlier code set the default only on "use_specific_logits_processor",
    # which no argument writes. That legacy attribute is kept (always False)
    # so any existing reader of it is not broken.
    optional_inputs.set_defaults(use_logits_processor=False, use_specific_logits_processor=False)

    optional_inputs.add_argument(
        "--collect_cross_qk",
        required=False,
        action="store_true",
        help="Beam search model collect stacked cross QK.",
    )
    optional_inputs.set_defaults(collect_cross_qk=False)

    optional_inputs.add_argument(
        "--extra_decoding_ids",
        required=False,
        action="store_true",
        help="Need extra starting decoding ids for some feature like cross qk. Default is false.",
    )
    optional_inputs.set_defaults(extra_decoding_ids=False)

    optional_inputs.add_argument(
        "-t",
        "--use_temperature",
        required=False,
        action="store_true",
        help="Use temperature as an extra graph input for the WhisperBeamSearch op",
    )
    optional_inputs.set_defaults(use_temperature=False)

    optional_inputs.add_argument(
        "--no_repeat_ngram_size",
        type=int,
        default=0,
        help="default to 0",
    )

    #############################################################
    # Optional outputs for Whisper
    # (listed below in the order that WhisperBeamSearch expects)
    #############################################################

    optional_outputs.add_argument(
        "--output_sequence_scores",
        required=False,
        action="store_true",
        help="Beam search model output scores for each generated sequence.",
    )
    optional_outputs.set_defaults(output_sequence_scores=False)

    optional_outputs.add_argument(
        "--output_scores",
        required=False,
        action="store_true",
        help="Beam search model output scores over vocab per generated token.",
    )
    optional_outputs.set_defaults(output_scores=False)

    optional_outputs.add_argument(
        "--output_cross_qk",
        required=False,
        action="store_true",
        help="Beam search model output collected qk as output. Also hint collect_cross_qk",
    )
    optional_outputs.set_defaults(output_cross_qk=False)

    optional_outputs.add_argument(
        "--cross_qk_onnx_model",
        required=False,
        type=str,
        default=None,
        help="The model which consumes cross_qk outputs.",
    )

    optional_outputs.add_argument(
        "--output_no_speech_probs",
        required=False,
        action="store_true",
        help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.",
    )
    optional_outputs.set_defaults(output_no_speech_probs=False)

    ###################################
    # Quantization options for Whisper
    ###################################

    quant_args.add_argument(
        "--quantize_embedding_layer",
        required=False,
        action="store_true",
        help="Quantize MatMul, GEMM, and Gather.",
    )
    quant_args.set_defaults(quantize_embedding_layer=False)

    quant_args.add_argument(
        "--quantize_per_channel",
        required=False,
        action="store_true",
        help="Quantize weights per each channel.",
    )
    quant_args.set_defaults(quantize_per_channel=False)

    quant_args.add_argument(
        "--quantize_reduce_range",
        required=False,
        action="store_true",
        help="Quantize weights with 7 bits.",
    )
    quant_args.set_defaults(quantize_reduce_range=False)

    args = parser.parse_args(argv)
    # Outputting cross QK requires collecting it, so --output_cross_qk implies --collect_cross_qk.
    args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk

    return args
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def export_onnx_models(
    model_name_or_path,
    model_impl,
    cache_dir,
    output_dir,
    use_gpu,
    use_external_data_format,
    optimize_onnx,
    precision,
    verbose,
    use_forced_decoder_ids: bool = False,
    merge_encoder_and_decoder_init: bool = True,
    overwrite: bool = False,
    disable_auto_mixed_precision: bool = False,
    use_int32_inputs: bool = True,
    quantize_embedding_layer: bool = False,
    quantize_per_channel: bool = False,
    quantize_reduce_range: bool = False,
    state_dict_path: str = "",
    provider: str = "cpu",
):
    """Export every Whisper sub-model to ONNX, then optionally optimize and/or quantize it.

    Each entry returned by ``WhisperHelper.load_model`` is exported to
    ``output_dir`` (suffix ``"_<name>"``). When ``optimize_onnx`` is set or
    ``precision`` is not FLOAT32, a second file with suffix
    ``"_<name>_<precision>"`` is produced by graph optimization and/or INT8
    dynamic quantization. Each final model is loaded once in an ONNX Runtime
    session as a smoke test.

    Args:
        model_name_or_path: Model identifier or local path, passed to
            ``WhisperHelper.load_model`` and used to name the ONNX files.
        model_impl: Implementation selector forwarded to ``WhisperHelper.load_model``.
        cache_dir: Cache directory forwarded to ``WhisperHelper.load_model``.
        output_dir: Directory in which the ONNX files are written.
        use_gpu: Load the PyTorch models on ``cuda:0`` instead of CPU; also
            forwarded to the optimizer and the ONNX Runtime session.
        use_external_data_format: Forwarded to export, optimization and
            quantization to store weights outside the ``.onnx`` file.
        optimize_onnx: Run ``WhisperHelper.optimize_onnx`` on each exported model.
        precision: Target ``Precision``. FLOAT16 is forwarded to the optimizer as
            its fp16 flag (so it only takes effect together with
            ``optimize_onnx``). INT8 applies dynamic quantization.
        verbose: Forwarded to ``WhisperHelper.export_onnx``.
        use_forced_decoder_ids: Not used by this function; kept so existing
            positional callers keep working.
        merge_encoder_and_decoder_init: Forwarded to ``WhisperHelper.load_model``.
        overwrite: Re-export / re-optimize even when the target files exist.
        disable_auto_mixed_precision: When True, passes
            ``auto_mixed_precision=False`` to the optimizer.
        use_int32_inputs: Forwarded to ``WhisperHelper.export_onnx``.
        quantize_embedding_layer: Also quantize ``Gather`` nodes for INT8.
        quantize_per_channel: Forwarded as ``per_channel`` to ``quantize_dynamic``.
        quantize_reduce_range: Forwarded as ``reduce_range`` to ``quantize_dynamic``.
        state_dict_path: Forwarded to ``WhisperHelper.load_model``.
        provider: Execution provider name forwarded to the optimizer and session.

    Returns:
        List with the path of the final ONNX model for each sub-model, in the
        iteration order of the dict returned by ``WhisperHelper.load_model``.

    Raises:
        AssertionError: If ``create_onnxruntime_session`` returns ``None``.
    """
    device = torch.device("cuda:0" if use_gpu else "cpu")

    models = WhisperHelper.load_model(
        model_name_or_path, model_impl, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path
    )
    config = models["decoder"].config

    if (not use_external_data_format) and (config.num_hidden_layers > 24):
        logger.info("Try use_external_data_format when model size > 2GB")

    output_paths = []
    for name, model in models.items():
        print(f"========> Handling {name} model......")
        model.to(device)
        filename_suffix = "_" + name

        onnx_path = WhisperHelper.get_onnx_path(
            output_dir,
            model_name_or_path,
            suffix=filename_suffix,
            new_folder=False,
        )

        if overwrite or not os.path.exists(onnx_path):
            logger.info(f"Exporting ONNX model to {onnx_path}")
            # We have to clone model before exporting onnx, otherwise verify_onnx will report large difference.
            device_to_export = torch.device("cpu")
            cloned_model = copy.deepcopy(model).to(device_to_export)
            WhisperHelper.export_onnx(
                cloned_model,
                device_to_export,
                onnx_path,
                verbose,
                use_external_data_format,
                use_int32_inputs=use_int32_inputs,
            )
        else:
            logger.info(f"Skip exporting: existed ONNX model {onnx_path}")

        # Optimize the exported graph and/or convert its precision into a separate file.
        if optimize_onnx or precision != Precision.FLOAT32:
            output_path = WhisperHelper.get_onnx_path(
                output_dir,
                model_name_or_path,
                suffix=filename_suffix + "_" + str(precision),
                new_folder=False,
            )

            if overwrite or not os.path.exists(output_path):
                if optimize_onnx:
                    logger.info(f"Optimizing model to {output_path}")
                    WhisperHelper.optimize_onnx(
                        onnx_path,
                        output_path,
                        precision == Precision.FLOAT16,
                        config.encoder_attention_heads,
                        config.d_model,
                        use_external_data_format,
                        auto_mixed_precision=not disable_auto_mixed_precision,
                        use_gpu=use_gpu,
                        provider=provider,
                    )
                    # Quantization (below) must start from the optimized graph.
                    onnx_path = output_path
                elif precision != Precision.INT8:
                    # Within this function only the optimizer applies non-INT8 precisions
                    # (e.g. fp16), so without it nothing is written to output_path.
                    logger.warning(
                        f"Precision {precision} is only applied during graph optimization; "
                        f"no model was written to {output_path}. Enable optimize_onnx to produce it."
                    )

                if precision == Precision.INT8:
                    quantization.quantize_dynamic(
                        onnx_path,
                        output_path,
                        op_types_to_quantize=(
                            ["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"]
                        ),
                        use_external_data_format=use_external_data_format,
                        per_channel=quantize_per_channel,
                        reduce_range=quantize_reduce_range,
                        extra_options={"MatMulConstBOnly": True},
                    )
            else:
                # The file being reused is the optimized/quantized one, not the raw export.
                logger.info(f"Skip optimizing: existing ONNX model {output_path}")
        else:
            output_path = onnx_path

        # Smoke-test that the final model can be loaded by ONNX Runtime.
        ort_session = create_onnxruntime_session(
            output_path,
            use_gpu=use_gpu,
            provider=provider,
        )
        assert ort_session is not None

        output_paths.append(output_path)

    return output_paths
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def _remove_intermediate_onnx_files(output_dir, beam_model_path):
    """Delete intermediate model files left next to the chained beam-search model.

    Only regular files in ``output_dir`` whose names start with the
    beam-search model's name prefix (the part of its file name before
    ``"_beamsearch"``, followed by ``"_"``) are removed. Any name containing
    ``"_beamsearch"`` is kept. Unrelated files and sub-directories are left untouched.

    Args:
        output_dir: Directory holding the exported models; ``""`` means the
            current directory.
        beam_model_path: Path of the chained beam-search model.
    """
    directory = output_dir or os.curdir
    beam_filename = os.path.basename(beam_model_path)
    marker_index = beam_filename.find("_beamsearch")
    if marker_index <= 0:
        # Without a model-name prefix we cannot tell our files from the user's.
        logger.warning(f"Cannot derive a model name prefix from {beam_model_path}; skipping cleanup of {directory}")
        return

    prefix = beam_filename[:marker_index] + "_"
    for filename in os.listdir(directory):
        if "_beamsearch" in filename or not filename.startswith(prefix):
            continue
        file_path = os.path.join(directory, filename)
        if os.path.isfile(file_path):
            os.remove(file_path)


def main(argv=None):
    """Export Whisper to ONNX, optionally chain a beam-search model, and check its parity.

    Args:
        argv: Argument list handed to ``parse_arguments``. ``None`` lets argparse
            read ``sys.argv[1:]``.

    Returns:
        The value returned by ``WhisperHelper.verify_onnx`` for the chained
        beam-search model (compared against 1e-4 for the closeness log). Returns
        0 when beam-search chaining is disabled or the parity check raises.
    """
    args = parse_arguments(argv)

    setup_logger(args.verbose)

    logger.info(f"Arguments:{args}")

    cache_dir = args.cache_dir
    output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
    prepare_environment(cache_dir, output_dir, args.use_gpu)

    if args.precision == Precision.FLOAT16:
        assert args.use_gpu, "fp16 requires --use_gpu"

    if args.optimize_onnx:
        logger.warning("Applying graph optimization for Whisper...")

    output_paths = export_onnx_models(
        args.model_name_or_path,
        args.model_impl,
        cache_dir,
        output_dir,
        args.use_gpu,
        args.use_external_data_format,
        args.optimize_onnx,
        args.precision,
        args.verbose,
        args.use_forced_decoder_ids,
        not args.separate_encoder_and_decoder_init,
        args.overwrite,
        args.disable_auto_mixed_precision,
        not args.use_int64_inputs,
        args.quantize_embedding_layer,
        args.quantize_per_channel,
        args.quantize_reduce_range,
        args.state_dict_path,
        args.provider,
    )

    max_diff = 0
    if not args.no_beam_search_op:
        logger.info("Chaining model ... :")
        args.beam_model_output_dir = WhisperHelper.get_onnx_path(
            output_dir,
            args.model_name_or_path,
            suffix="_beamsearch",
            new_folder=False,
        )
        for path in output_paths:
            # Match on the file name only: a directory name containing "decoder" or
            # "encoder_decoder" must not decide which sub-model a path is.
            filename = os.path.basename(path)
            if "encoder_decoder" in filename:
                args.encoder_path = path
            elif "decoder" in filename:
                args.decoder_path = path
        chain_model(args)
        output_paths.append(args.beam_model_output_dir)

        # Check chained model
        ort_session = create_onnxruntime_session(
            args.beam_model_output_dir,
            use_gpu=args.use_gpu,
            provider=args.provider,
        )
        device = torch.device("cuda:0" if args.use_gpu else "cpu")

        # Wrap parity check in try-except to allow export to continue in case this produces an error
        try:
            with torch.no_grad():
                # Verify batched decoding with prompts for whisper openai implementation
                if args.model_impl == "openai" and args.use_forced_decoder_ids:
                    max_diff = WhisperHelper.verify_onnx(
                        args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
                    )
                else:
                    max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
            if max_diff > 1e-4:
                logger.warning("PyTorch and ONNX Runtime results are NOT close")
            else:
                logger.info("PyTorch and ONNX Runtime results are close")
        except Exception as e:
            logger.warning(
                f"An error occurred while trying to verify parity between PyTorch and ONNX Runtime: {e}", exc_info=True
            )

        # Remove the intermediate ONNX models this script saved in the output directory,
        # without touching unrelated files or sub-directories.
        _remove_intermediate_onnx_files(output_dir, args.beam_model_output_dir)
        output_paths = [args.beam_model_output_dir]

    logger.info(f"Done! Outputs: {output_paths}")
    return max_diff


if __name__ == "__main__":
    # main()'s return value is not used as the process exit status.
    main()
|