onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
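For orientation, the wheel's distinguishing content is the bundled DirectML execution provider (onnxruntime/capi/DirectML.dll above). A minimal sketch of consuming it — the model path and input shape are placeholders, not part of the package:

import numpy
import onnxruntime

# Prefer DirectML, fall back to CPU. "DmlExecutionProvider" is the provider
# name this wheel registers; "model.onnx" is a placeholder path.
session = onnxruntime.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

# Query the real input name instead of hard-coding it.
inp = session.get_inputs()[0]
x = numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)  # placeholder shape
outputs = session.run(None, {inp.name: x})
print(inp.name, [o.shape for o in outputs])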
onnxruntime/transformers/models/whisper/whisper_decoder.py
@@ -0,0 +1,402 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Union

import numpy
import onnx
import torch
from io_binding_helper import TypeHelper
from models.t5.past_helper import PastKeyValuesHelper
from onnx_model import OnnxModel
from torch_onnx_export_helper import torch_onnx_export
from transformers import WhisperConfig, file_utils
from whisper_openai_helper import WhisperDecoderInitOpenai

from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)


class WhisperDecoderInit(torch.nn.Module):
    """A Whisper decoder to create initial past key values.

    This model is only called once at the start of decoding.
    """

    def __init__(
        self,
        decoder: torch.nn.Module,
        config: WhisperConfig,
        decoder_start_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.decoder = decoder
        self.config = config
        self.decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor,
    ):
        encoder_outputs = file_utils.ModelOutput()
        encoder_outputs["last_hidden_state"] = encoder_hidden_states
        encoder_outputs["hidden_states"] = None
        encoder_outputs["attentions"] = None

        out = self.decoder.model(
            None,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            past_key_values=None,
            use_cache=True,
            return_dict=True,
        )
        logits = self.decoder.proj_out(out[0])
        return logits, out.past_key_values, out.encoder_last_hidden_state


class WhisperDecoder(torch.nn.Module):
    """A Whisper decoder with past key values"""

    def __init__(self, decoder, config, model_impl: str = "hf", model: torch.nn.Module = None):
        super().__init__()
        self.decoder = decoder
        self.config = config
        self.model_impl = model_impl
        if model is not None:
            self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder)

    def forward(self, decoder_input_ids, *past):
        encoder_outputs = file_utils.ModelOutput()
        dummy_encoder_hidden_states = torch.randn((decoder_input_ids.shape[0], 3000, int(self.config.d_model)))
        encoder_outputs["last_hidden_state"] = dummy_encoder_hidden_states
        encoder_outputs["hidden_states"] = dummy_encoder_hidden_states
        encoder_outputs["attentions"] = None

        if self.model_impl == "openai":
            dummy_encoder_hidden_states.unsqueeze(0)
            dec_out, present = self.whisper_decoder_openai_init(
                decoder_input_ids, dummy_encoder_hidden_states, past=past
            )
            return dec_out, present

        if len(past) == 0:
            past_key_values = None
        else:
            past_key_values = PastKeyValuesHelper.back_group_by_layer(past)

        decoder_out = self.decoder(
            None,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
        logits = decoder_out[0]
        present_self, _ = PastKeyValuesHelper.group_by_self_and_cross(decoder_out.past_key_values)
        return logits, present_self


class WhisperDecoderInputs:
    def __init__(
        self,
        decoder_input_ids,
        past_key_values=None,
    ):
        self.decoder_input_ids: torch.LongTensor = decoder_input_ids
        self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values

    @staticmethod
    def create_dummy(
        config: WhisperConfig,
        batch_size: int,
        encode_sequence_length: int,
        past_decode_sequence_length: int,
        device: torch.device,
        float16: bool = False,
        use_int32_inputs: bool = False,
        model_impl: str = "hf",
    ):  # -> WhisperDecoderInputs:
        """Create dummy inputs for WhisperDecoder.

        Args:
            config (WhisperConfig): config of the model
            batch_size (int): batch size
            encode_sequence_length (int): sequence length of input_ids for encoder
            past_decode_sequence_length (int): past sequence length of input_ids for decoder
            device (torch.device): device of output tensors
            float16 (bool): whether inputs use float16 instead of float32
            use_int32_inputs (bool): whether to use int32 instead of int64 for some inputs

        Returns:
            WhisperDecoderInputs: dummy inputs for decoder
        """
        num_attention_heads: int = config.encoder_attention_heads
        num_layers: int = config.decoder_layers  # + config.encoder_layers
        vocab_size: int = config.vocab_size

        # head_size is hidden_size / num_attention_heads.
        # For example, whisper-large has d_model=1280 and num_heads=20.
        head_size: int = config.d_model // config.encoder_attention_heads

        sequence_length: int = 1  # fixed for decoding
        decoder_input_ids = torch.randint(
            low=0,
            high=vocab_size - 1,
            size=(batch_size, sequence_length),
            dtype=(torch.int32 if use_int32_inputs else torch.int64),
            device=device,
        )

        float_type = torch.float16 if float16 else torch.float32

        if past_decode_sequence_length > 0:
            self_attention_past_shape = [
                batch_size,
                num_attention_heads,
                past_decode_sequence_length,
                head_size,
            ]
            cross_attention_past_shape = [
                batch_size,
                num_attention_heads,
                encode_sequence_length if model_impl == "hf" else past_decode_sequence_length,
                head_size,
            ]

            past = []
            for _ in range(2 * num_layers):
                past.append(torch.rand(self_attention_past_shape, dtype=float_type, device=device))

            for _ in range(2 * num_layers):
                past.append(torch.rand(cross_attention_past_shape, dtype=float_type, device=device))
        else:
            past = None

        return WhisperDecoderInputs(decoder_input_ids, past)

    def to_list(self) -> List:
        input_list = [self.decoder_input_ids]
        if self.past_key_values:
            input_list.extend(self.past_key_values)
        return input_list

    def to_fp32(self):
        past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
        return WhisperDecoderInputs(
            self.decoder_input_ids.clone(),
            past,
        )


class WhisperDecoderHelper:
    @staticmethod
    def export_onnx(
        decoder: WhisperDecoder,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export decoder to ONNX

        Args:
            decoder (Union[WhisperDecoder, WhisperDecoderInit]): decoder object
            device (torch.device): device of decoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs
        """
        assert isinstance(decoder, (WhisperDecoder, WhisperDecoderInit))

        inputs = WhisperDecoderInputs.create_dummy(
            decoder.config,
            batch_size=2,
            encode_sequence_length=3000,
            past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0,
            device=device,
            use_int32_inputs=use_int32_inputs,
            model_impl=decoder.model_impl,
        )
        input_list = inputs.to_list()

        # Fix past disappearing bug - duplicate first past entry
        # input_list.insert(2, input_list[2])

        past_names = PastKeyValuesHelper.get_past_names(decoder.config.decoder_layers, present=False)
        present_names = PastKeyValuesHelper.get_past_names(decoder.config.decoder_layers, present=True)
        present_self_names = present_names[: 2 * decoder.config.decoder_layers]

        input_past_names = past_names if isinstance(decoder, WhisperDecoder) else []
        output_present_names = present_self_names if isinstance(decoder, WhisperDecoder) else present_names
        output_names = ["logits", *output_present_names]

        # Shape of input tensors (sequence_length==1):
        #    input_ids: (batch_size, sequence_length)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        # Shape of output tensors:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        input_names = ["input_ids"]
        input_names.extend(input_past_names)

        dynamic_axes = {
            "input_ids": {0: "batch_size"},
            "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length / 2"},
            "logits": {0: "batch_size", 1: "sequence_length"},
        }

        for name in input_past_names:
            dynamic_axes[name] = {
                0: "batch_size",
                2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
            }

        for name in output_present_names:
            if "cross" in name:
                dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
            else:  # self attention past state
                if isinstance(decoder, WhisperDecoder):
                    dynamic_axes[name] = {
                        0: "batch_size",
                        2: "past_decode_sequence_length + 1",
                    }
                else:
                    dynamic_axes[name] = {
                        0: "batch_size",
                        # 2: 'sequence_length'
                    }

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                decoder,
                args=tuple(input_list),
                f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=17,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )

    @staticmethod
    def onnxruntime_inference(ort_session, inputs: WhisperDecoderInputs):
        """Run inference of ONNX model."""
        logger.debug("start onnxruntime_inference")

        ort_inputs = {
            "input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
        }

        if inputs.past_key_values:
            assert len(inputs.past_key_values) % 4 == 0
            num_layers = int(len(inputs.past_key_values) / 4)
            past_names = PastKeyValuesHelper.get_past_names(num_layers)
            for i, past_tensor in enumerate(inputs.past_key_values):
                ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())

        ort_outputs = ort_session.run(None, ort_inputs)
        return ort_outputs

    @staticmethod
    def verify_onnx(
        model: Union[WhisperDecoder, WhisperDecoderInit],
        ort_session: InferenceSession,
        device: torch.device,
        use_int32_inputs: bool,
        max_cases: int = 4,
    ):
        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
        float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"

        test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
        test_cases_max_diff = []
        for (
            batch_size,
            encode_sequence_length,
            past_decode_sequence_length,
        ) in test_cases[:max_cases]:
            if isinstance(model, WhisperDecoderInit):
                dec_seq_len = 0
            else:
                dec_seq_len = past_decode_sequence_length

            inputs = WhisperDecoderInputs.create_dummy(
                model.config,
                batch_size,
                encode_sequence_length,
                dec_seq_len,
                device=device,
                float16=float16,
                use_int32_inputs=use_int32_inputs,
            )

            # Use the fp32 PyTorch model as baseline even when the ONNX model is fp16
            input_list = inputs.to_fp32().to_list()

            # Run inference of PyTorch model
            with torch.no_grad():
                torch_outputs = model(*input_list)

            ort_outputs = WhisperDecoderHelper.onnxruntime_inference(ort_session, inputs)

            max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
            max_diff_all = max_diff
            logger.debug(f"logits max_diff={max_diff}")

            for i in range(2 * model.config.decoder_layers):
                max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
                logger.debug(f"self attention past state {i} max_diff={max_diff}")
                max_diff_all = max(max_diff_all, max_diff)

            if isinstance(model, WhisperDecoderInit):
                for i in range(2 * model.config.decoder_layers):
                    max_diff = numpy.amax(
                        numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * model.config.decoder_layers + i])
                    )
                    logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                    max_diff_all = max(max_diff_all, max_diff)

            test_cases_max_diff.append(max_diff_all)
            logger.info(
                "batch_size=%s, encode_sequence_length=%s, past_decode_sequence_length=%s, max_diff=%s",
                batch_size,
                encode_sequence_length,
                past_decode_sequence_length,
                max_diff_all,
            )

        return max_diff_all
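A short usage sketch for the helpers above (not from the package; it assumes a Hugging Face openai/whisper-tiny checkpoint and that the onnxruntime transformers tools directory is on sys.path so the module's flat imports resolve):

import torch
from transformers import WhisperForConditionalGeneration

from models.whisper.whisper_decoder import WhisperDecoder, WhisperDecoderHelper
from onnxruntime import InferenceSession

device = torch.device("cpu")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()

# Wrap the full HF model; forward() then takes (decoder_input_ids, *past)
# in the flattened layout the exporter expects.
decoder = WhisperDecoder(model, model.config).to(device)

WhisperDecoderHelper.export_onnx(decoder, device, "whisper_decoder.onnx", use_int32_inputs=True)

# Compare ONNX Runtime output against the PyTorch baseline.
session = InferenceSession("whisper_decoder.onnx", providers=["CPUExecutionProvider"])
max_diff = WhisperDecoderHelper.verify_onnx(decoder, session, device, use_int32_inputs=True)
print(f"decoder max_diff={max_diff}")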
onnxruntime/transformers/models/whisper/whisper_encoder.py
@@ -0,0 +1,164 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import os
import tempfile
from pathlib import Path
from typing import List

import numpy
import onnx
import torch
from onnx_model import OnnxModel
from torch_onnx_export_helper import torch_onnx_export
from transformers import WhisperConfig

from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)


class WhisperEncoder(torch.nn.Module):
    """Whisper encoder outputs only the last hidden state"""

    def __init__(self, encoder, config: WhisperConfig, model_impl: str = "hf"):
        super().__init__()
        self.encoder = encoder
        self.config = config
        self.model_impl = model_impl

    def forward(self, input_features):
        if self.model_impl == "openai":
            return self.encoder(input_features)
        return self.encoder.model.encoder(input_features)[0]


class WhisperEncoderInputs:
    def __init__(self, input_features):
        self.input_ids: torch.FloatTensor = input_features

    @staticmethod
    def create_dummy(
        batch_size: int,
        sequence_length: int,
        feature_size: int,
        device: torch.device,
        use_int32_inputs: bool = False,
    ):
        """Create dummy inputs for Whisper encoder.

        Args:
            batch_size (int): batch size
            sequence_length (int): sequence length
            feature_size (int): feature size for spectrogram input
            device (torch.device): device of output tensors

        Returns:
            WhisperEncoderInputs: dummy inputs for encoder
        """

        input_features = torch.randn(
            size=(batch_size, feature_size, sequence_length),
            device=device,
        )
        return WhisperEncoderInputs(input_features)

    def to_list(self) -> List:
        if self.input_ids is None:
            return []
        return [self.input_ids]


class WhisperEncoderHelper:
    @staticmethod
    def export_onnx(
        encoder,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export encoder to ONNX

        Args:
            encoder (WhisperEncoder): encoder object
            device (torch.device): device of encoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
        """
        config = encoder.config
        encoder_inputs = WhisperEncoderInputs.create_dummy(
            batch_size=2,
            sequence_length=3000,
            feature_size=config.num_mel_bins,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                encoder,
                args=tuple(encoder_inputs.to_list()),
                f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=["input_features"],
                output_names=["hidden_states"],
                dynamic_axes={
                    "input_features": {0: "batch_size", 1: "feature_size", 2: "sequence_length"},
                    "hidden_states": {0: "batch_size", 1: "sequence_length"},
                },
                opset_version=17,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )

    @staticmethod
    def onnxruntime_inference(ort_session, inputs: WhisperEncoderInputs):
        """Run inference of ONNX model."""
        ort_inputs = {
            "input_features": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()),
        }

        return ort_session.run(None, ort_inputs)

    @staticmethod
    def verify_onnx(
        model: WhisperEncoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
    ):
        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
        inputs = WhisperEncoderInputs.create_dummy(
            batch_size=4,
            sequence_length=11,
            feature_size=model.config.num_mel_bins,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )
        input_list = inputs.to_list()
        torch_outputs = model(*input_list)

        ort_outputs = WhisperEncoderHelper.onnxruntime_inference(ort_session, inputs)

        max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0]))

        logger.info(f"max_diff={max_diff}")

        return max_diff
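The encoder follows the same pattern; a matching sketch under the same assumptions (whisper-tiny uses 80 mel bins, so feature_size comes from config.num_mel_bins):

import torch
from transformers import WhisperForConditionalGeneration

from models.whisper.whisper_encoder import WhisperEncoder, WhisperEncoderHelper
from onnxruntime import InferenceSession

device = torch.device("cpu")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()

# model_impl="hf" routes forward() through model.model.encoder, so the full
# HF model is passed in as the wrapped encoder.
encoder = WhisperEncoder(model, model.config, model_impl="hf").to(device)

WhisperEncoderHelper.export_onnx(encoder, device, "whisper_encoder.onnx")

session = InferenceSession("whisper_encoder.onnx", providers=["CPUExecutionProvider"])
max_diff = WhisperEncoderHelper.verify_onnx(encoder, session, device)
print(f"encoder max_diff={max_diff}")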