onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Dict, Tuple, Union
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
from float16 import float_to_float16_max_diff
|
|
15
|
+
from onnx_model import OnnxModel
|
|
16
|
+
from optimizer import optimize_model
|
|
17
|
+
from packaging import version
|
|
18
|
+
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
|
|
19
|
+
from transformers import __version__ as transformers_version
|
|
20
|
+
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
|
|
21
|
+
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
|
|
22
|
+
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
|
|
23
|
+
|
|
24
|
+
from onnxruntime import InferenceSession
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
PRETRAINED_WHISPER_MODELS = [
|
|
29
|
+
"whisper-tiny",
|
|
30
|
+
"whisper-tiny.en",
|
|
31
|
+
"whisper-base",
|
|
32
|
+
"whisper-base.en",
|
|
33
|
+
"whisper-small",
|
|
34
|
+
"whisper-small.en",
|
|
35
|
+
"whisper-medium",
|
|
36
|
+
"whisper-medium.en",
|
|
37
|
+
"whisper-large",
|
|
38
|
+
"whisper-large-v2",
|
|
39
|
+
"whisper-large-v3",
|
|
40
|
+
]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class WhisperHelper:
|
|
44
|
+
@staticmethod
|
|
45
|
+
def get_onnx_path(
|
|
46
|
+
output_dir: str,
|
|
47
|
+
model_name_or_path: str,
|
|
48
|
+
suffix: str = "",
|
|
49
|
+
new_folder: bool = False,
|
|
50
|
+
) -> str:
|
|
51
|
+
"""Build onnx path
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
output_dir (str): output directory
|
|
55
|
+
model_name_or_path (str): pretrained model name, or path to the model checkpoint
|
|
56
|
+
suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" will be appended to file name. Defaults to None.
|
|
57
|
+
new_folder (bool, optional): create a new directory for the model. Defaults to False.
|
|
58
|
+
|
|
59
|
+
Returns:
|
|
60
|
+
str: path of onnx model
|
|
61
|
+
"""
|
|
62
|
+
model_name = model_name_or_path
|
|
63
|
+
if os.path.isdir(model_name_or_path):
|
|
64
|
+
model_name = Path(model_name_or_path).parts[-1]
|
|
65
|
+
else:
|
|
66
|
+
model_name = model_name.split("/")[-1]
|
|
67
|
+
|
|
68
|
+
model_name += suffix
|
|
69
|
+
|
|
70
|
+
directory = os.path.join(output_dir, model_name) if new_folder else output_dir
|
|
71
|
+
return os.path.join(directory, model_name + ".onnx")
|
|
72
|
+
|
|
73
|
+
@staticmethod
|
|
74
|
+
def load_model_openai(
|
|
75
|
+
model_name_or_path: str,
|
|
76
|
+
cache_dir: str,
|
|
77
|
+
device: torch.device,
|
|
78
|
+
) -> torch.nn.Module:
|
|
79
|
+
"""Load model given a pretrained name or path, then build models for ONNX conversion.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
model_name_or_path (str): pretrained model name or path
|
|
83
|
+
cache_dir (str): cache directory
|
|
84
|
+
device (torch.device): device to run the model
|
|
85
|
+
merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True.
|
|
86
|
+
Returns:
|
|
87
|
+
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
|
|
88
|
+
"""
|
|
89
|
+
from whisper import _ALIGNMENT_HEADS, _MODELS, _download
|
|
90
|
+
from whisper.model import ModelDimensions, Whisper
|
|
91
|
+
|
|
92
|
+
in_memory = False
|
|
93
|
+
|
|
94
|
+
model_name = model_name_or_path.split("/")[-1][8:]
|
|
95
|
+
checkpoint_file, alignment_heads = None, None
|
|
96
|
+
if model_name in _MODELS:
|
|
97
|
+
checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory)
|
|
98
|
+
alignment_heads = _ALIGNMENT_HEADS[model_name]
|
|
99
|
+
|
|
100
|
+
with open(checkpoint_file, "rb") as fp:
|
|
101
|
+
checkpoint = torch.load(fp, map_location=device)
|
|
102
|
+
del checkpoint_file
|
|
103
|
+
|
|
104
|
+
dims = ModelDimensions(**checkpoint["dims"])
|
|
105
|
+
model = Whisper(dims)
|
|
106
|
+
model.load_state_dict(checkpoint["model_state_dict"])
|
|
107
|
+
|
|
108
|
+
if alignment_heads is not None:
|
|
109
|
+
model.set_alignment_heads(alignment_heads)
|
|
110
|
+
return model.to(device)
|
|
111
|
+
|
|
112
|
+
@staticmethod
|
|
113
|
+
def load_model(
|
|
114
|
+
model_name_or_path: str,
|
|
115
|
+
model_impl: str,
|
|
116
|
+
cache_dir: str,
|
|
117
|
+
device: torch.device,
|
|
118
|
+
merge_encoder_and_decoder_init: bool = True,
|
|
119
|
+
state_dict_path: str = "",
|
|
120
|
+
) -> Dict[str, torch.nn.Module]:
|
|
121
|
+
"""Load model given a pretrained name or path, then build models for ONNX conversion.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
model_name_or_path (str): pretrained model name or path
|
|
125
|
+
cache_dir (str): cache directory
|
|
126
|
+
device (torch.device): device to run the model
|
|
127
|
+
merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True.
|
|
128
|
+
Returns:
|
|
129
|
+
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
|
|
130
|
+
"""
|
|
131
|
+
extra_kwargs = {}
|
|
132
|
+
if version.parse(transformers_version) >= version.parse("4.36.0"):
|
|
133
|
+
extra_kwargs["attn_implementation"] = "eager"
|
|
134
|
+
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs)
|
|
135
|
+
|
|
136
|
+
if model_impl == "openai":
|
|
137
|
+
openai_model = WhisperHelper.load_model_openai(model_name_or_path, cache_dir, device)
|
|
138
|
+
model_encoder, model_decoder = openai_model.encoder, openai_model.decoder
|
|
139
|
+
passed_model = openai_model
|
|
140
|
+
else:
|
|
141
|
+
model_encoder, model_decoder = model, model
|
|
142
|
+
passed_model = None
|
|
143
|
+
|
|
144
|
+
if state_dict_path:
|
|
145
|
+
model.load_state_dict(torch.load(state_dict_path), strict=False)
|
|
146
|
+
|
|
147
|
+
decoder = WhisperDecoder(model_decoder, model.config, model_impl=model_impl, model=passed_model)
|
|
148
|
+
decoder.eval().to(device)
|
|
149
|
+
|
|
150
|
+
if merge_encoder_and_decoder_init:
|
|
151
|
+
encoder_decoder_init = WhisperEncoderDecoderInit(
|
|
152
|
+
model_encoder,
|
|
153
|
+
model_decoder,
|
|
154
|
+
model.config,
|
|
155
|
+
decoder_start_token_id=None,
|
|
156
|
+
model_impl=model_impl,
|
|
157
|
+
model=passed_model,
|
|
158
|
+
)
|
|
159
|
+
return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder}
|
|
160
|
+
else:
|
|
161
|
+
encoder = WhisperEncoder(model.model.encoder, model.config)
|
|
162
|
+
encoder.eval().to(device)
|
|
163
|
+
decoder_init = WhisperDecoderInit(model.decoder, model.config)
|
|
164
|
+
decoder_init.eval().to(device)
|
|
165
|
+
return {
|
|
166
|
+
"encoder": encoder,
|
|
167
|
+
"decoder": decoder,
|
|
168
|
+
"decoder_init": decoder_init,
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
@staticmethod
|
|
172
|
+
def export_onnx(
|
|
173
|
+
model: Union[WhisperEncoder, WhisperDecoder, WhisperDecoderInit, WhisperEncoderDecoderInit],
|
|
174
|
+
device: torch.device,
|
|
175
|
+
onnx_model_path: str,
|
|
176
|
+
verbose: bool = True,
|
|
177
|
+
use_external_data_format: bool = False,
|
|
178
|
+
use_decoder_input_ids: bool = True,
|
|
179
|
+
use_int32_inputs: bool = False,
|
|
180
|
+
):
|
|
181
|
+
if isinstance(model, WhisperEncoder):
|
|
182
|
+
WhisperEncoderHelper.export_onnx(
|
|
183
|
+
model,
|
|
184
|
+
device,
|
|
185
|
+
onnx_model_path,
|
|
186
|
+
verbose,
|
|
187
|
+
use_external_data_format,
|
|
188
|
+
)
|
|
189
|
+
elif isinstance(model, WhisperEncoderDecoderInit):
|
|
190
|
+
WhisperEncoderDecoderInitHelper.export_onnx(
|
|
191
|
+
model,
|
|
192
|
+
device,
|
|
193
|
+
onnx_model_path,
|
|
194
|
+
use_decoder_input_ids,
|
|
195
|
+
verbose,
|
|
196
|
+
use_external_data_format,
|
|
197
|
+
use_int32_inputs,
|
|
198
|
+
)
|
|
199
|
+
else:
|
|
200
|
+
WhisperDecoderHelper.export_onnx(
|
|
201
|
+
model,
|
|
202
|
+
device,
|
|
203
|
+
onnx_model_path,
|
|
204
|
+
verbose,
|
|
205
|
+
use_external_data_format,
|
|
206
|
+
use_int32_inputs,
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
@staticmethod
|
|
210
|
+
def auto_mixed_precision(
|
|
211
|
+
onnx_model: OnnxModel,
|
|
212
|
+
op_block_list: Tuple[str] = (
|
|
213
|
+
"SimplifiedLayerNormalization",
|
|
214
|
+
"SkipSimplifiedLayerNormalization",
|
|
215
|
+
"Relu",
|
|
216
|
+
"Add",
|
|
217
|
+
),
|
|
218
|
+
):
|
|
219
|
+
"""Convert model to mixed precision.
|
|
220
|
+
It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically.
|
|
221
|
+
Args:
|
|
222
|
+
onnx_model (OnnxModel): optimized ONNX model
|
|
223
|
+
op_block_list (List[str], optional): . Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"]
|
|
224
|
+
Returns:
|
|
225
|
+
parameters(dict): a dictionary of parameters used in float16 conversion
|
|
226
|
+
"""
|
|
227
|
+
op_full_set = set([node.op_type for node in onnx_model.nodes()])
|
|
228
|
+
fp32_op_set = set(op_block_list)
|
|
229
|
+
fp16_op_set = op_full_set.difference(fp32_op_set)
|
|
230
|
+
logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
|
|
231
|
+
|
|
232
|
+
# logits is the first output
|
|
233
|
+
logits_output_name = onnx_model.graph().output[0].name
|
|
234
|
+
|
|
235
|
+
# We use the weight in last MatMul node to detect whether the model is stored with float16 weights from training.
|
|
236
|
+
is_weight_fp16_precision = False
|
|
237
|
+
output_name_to_node = onnx_model.output_name_to_node()
|
|
238
|
+
assert logits_output_name in output_name_to_node
|
|
239
|
+
node = output_name_to_node[logits_output_name]
|
|
240
|
+
last_matmul_node = None
|
|
241
|
+
if node.op_type == "MatMul":
|
|
242
|
+
last_matmul_node = node
|
|
243
|
+
logger.info(f"Found last MatMul node for logits: {node.name}")
|
|
244
|
+
initializer = None
|
|
245
|
+
for input in node.input:
|
|
246
|
+
initializer = onnx_model.get_initializer(input)
|
|
247
|
+
if initializer is not None:
|
|
248
|
+
break
|
|
249
|
+
|
|
250
|
+
# when the max difference of value after converting float to float16 is lower than a threshold (1e-6),
|
|
251
|
+
# we can deduce that the weights are stored in float16 precision.
|
|
252
|
+
max_diff = float_to_float16_max_diff(initializer)
|
|
253
|
+
logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
|
|
254
|
+
is_weight_fp16_precision = max_diff < 1e-6
|
|
255
|
+
else:
|
|
256
|
+
logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
|
|
257
|
+
|
|
258
|
+
keep_io_types = []
|
|
259
|
+
node_block_list = []
|
|
260
|
+
if (not is_weight_fp16_precision) and (last_matmul_node is not None):
|
|
261
|
+
# When original weight is float32 precision, keep logits and last MatMul in float32 could get better precision.
|
|
262
|
+
keep_io_types = [logits_output_name]
|
|
263
|
+
node_block_list = [last_matmul_node.name]
|
|
264
|
+
|
|
265
|
+
parameters = {
|
|
266
|
+
"keep_io_types": keep_io_types,
|
|
267
|
+
"op_block_list": list(op_block_list),
|
|
268
|
+
"node_block_list": node_block_list,
|
|
269
|
+
"force_fp16_initializers": is_weight_fp16_precision,
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
logger.info(f"auto_mixed_precision parameters: {parameters}")
|
|
273
|
+
onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
|
|
274
|
+
|
|
275
|
+
return parameters
|
|
276
|
+
|
|
277
|
+
@staticmethod
def optimize_onnx(
    onnx_model_path: str,
    optimized_model_path: str,
    is_float16: bool,
    num_attention_heads: int,
    hidden_size: int,
    use_external_data_format: bool = False,
    auto_mixed_precision: bool = True,
    use_gpu: bool = False,
    provider: str = "cpu",
):
    """Optimize ONNX model with an option to convert it to use mixed precision.

    Args:
        onnx_model_path: path of the ONNX model to optimize.
        optimized_model_path: destination path for the optimized model.
        is_float16: convert the optimized model to float16 (mixed or full precision).
        num_attention_heads: forwarded to ``optimize_model`` as ``num_heads``.
        hidden_size: forwarded to ``optimize_model`` as ``hidden_size``.
        use_external_data_format: when True, ``opt_level`` is left as None instead of 2, and the
            model is saved with external data (all tensors in one file).
        auto_mixed_precision: when ``is_float16`` is set, use ``WhisperHelper.auto_mixed_precision``
            instead of a plain float32 -> float16 conversion.
        use_gpu: forwarded to ``optimize_model``.
        provider: execution provider name; ``"rocm"`` disables the MultiHeadAttention bias fusion.
    """
    from fusion_options import FusionOptions

    fusion_opts = FusionOptions("bart")
    fusion_opts.use_multi_head_attention = True
    # Bias fusion into MultiHeadAttention is switched off only for the ROCm provider.
    fusion_opts.disable_multi_head_attention_bias = provider == "rocm"

    graph_opt_level = None if use_external_data_format else 2
    optimized = optimize_model(
        onnx_model_path,
        model_type="bart",
        num_heads=num_attention_heads,
        hidden_size=hidden_size,
        opt_level=graph_opt_level,
        optimization_options=fusion_opts,
        use_gpu=use_gpu,
        only_onnxruntime=False,
    )

    if is_float16 and auto_mixed_precision:
        WhisperHelper.auto_mixed_precision(optimized)
    elif is_float16:
        optimized.convert_model_float32_to_float16(cast_input_output=False)

    optimized.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
|
|
315
|
+
|
|
316
|
+
@staticmethod
def pt_transcription_for_verify_onnx(
    processor: WhisperProcessor,
    pt_model: torch.nn.Module,
    device: torch.device,
    batch_size: int = 1,
    prompt_mode: bool = False,
):
    """Produce reference PyTorch transcriptions used to verify an exported ONNX model.

    Audio is taken from the ``hf-internal-testing/librispeech_asr_dummy`` dataset: sample 0 when
    ``batch_size == 1`` and sample 3 (used for both entries) when ``batch_size == 2``.

    Args:
        processor: Whisper processor used for feature extraction, prompt encoding and decoding.
        pt_model: PyTorch Whisper model whose ``generate`` output is the reference.
        device: device the input features are moved to before generation.
        batch_size: number of batch entries; only 1 and 2 are supported.
        prompt_mode: when True, each entry is generated separately with its own text prompt
            ("John has doubts", "Maria has grave doubts").

    Returns:
        Tuple ``(inputs, pt_transcription, pt_outputs, prompt_ids)``:
            inputs: the ``generate`` keyword arguments with ``early_stopping`` and ``use_cache``
                removed; ``inputs["input_features"]`` holds the batched features.
            pt_transcription: one decoded string per batch entry.
            pt_outputs: one array of generated token ids per batch entry.
            prompt_ids: per-entry prompt id arrays from ``processor.get_prompt_ids`` in prompt
                mode, otherwise an empty list.

    Raises:
        ValueError: if ``batch_size`` is not 1 or 2.
    """
    # Validate before any dataset download/installation work happens.
    if batch_size not in (1, 2):
        raise ValueError(f"batch_size must be 1 or 2 for Whisper verification, got {batch_size}")

    # Try to import `datasets` pip package
    try:
        from datasets import load_dataset
    except Exception as e:
        logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True)  # noqa: G201
        install_cmd = "pip install datasets"
        logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.")
        os.system(install_cmd)

    from datasets import load_dataset

    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

    def _features(sample_idx: int) -> torch.Tensor:
        """Extract input features for one dataset sample as a batch-of-one tensor."""
        return processor([ds[sample_idx]["audio"]["array"]], return_tensors="pt").input_features

    # Keep per-entry features: prompt mode generates each batch entry separately.
    if batch_size == 1:
        input_features_ = [_features(0)]
    else:
        sample_features = _features(3)
        input_features_ = [sample_features, sample_features]
    input_features = torch.cat(input_features_)

    max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1
    length_penalty, repetition_penalty = 1.0, 1.0
    inputs = {
        "input_features": input_features.to(device),
        "max_length": max_length,
        "min_length": min_length,
        "num_beams": num_beams,
        "num_return_sequences": num_return_sequences,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "early_stopping": True,
        "use_cache": True,
    }

    if prompt_mode:
        prompts = ["John has doubts", "Maria has grave doubts"]
        prompt_ids = [processor.get_prompt_ids(p) for p in prompts]
        pt_transcription = []
        pt_outputs = []
        # The looping for model.generate is necessary here due to the limitation as per
        # https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.prompt_ids
        # prompt_ids input requires a tensor of rank 1
        for i in range(batch_size):
            inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i])
            inputs["input_features"] = input_features_[i].to(device)
            pt_output = pt_model.generate(**inputs).detach().cpu().numpy()
            pt_outputs.append(pt_output)
            pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0])
        inputs["input_features"] = input_features
        del inputs["prompt_ids"]
    else:
        prompt_ids = []
        generated = pt_model.generate(**inputs).detach().cpu().numpy()
        # Decode every batch entry so callers can index transcriptions by batch position.
        pt_transcription = processor.batch_decode(generated, skip_special_tokens=True)
        pt_outputs = list(generated)

    # These generate-only options have no counterpart in the ONNX Runtime inputs.
    del inputs["early_stopping"]
    del inputs["use_cache"]
    return inputs, pt_transcription, pt_outputs, prompt_ids
|
|
385
|
+
|
|
386
|
+
@staticmethod
def select_transcription_options(
    batch_size: int,
    prompt_mode: bool,
):
    """Return the set of transcriptions accepted as correct during verification.

    Batched runs in prompt mode are checked against the prompted variants ("John has doubts ..." /
    "Maria has grave doubts ..."). Every other configuration is checked against the accepted
    punctuation variants of the "Mr. Quilter ..." sentence.
    """
    if not (batch_size > 1 and prompt_mode):
        quilter_variants = (
            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
            " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
            ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
        )
        return set(quilter_variants)

    prompted_variants = (
        " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I",
        " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky",
        " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I",
        " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I",
    )
    return set(prompted_variants)
|
|
418
|
+
|
|
419
|
+
@staticmethod
def verify_onnx(
    model_name_or_path: str,
    cache_dir: str,
    ort_session: InferenceSession,
    device: torch.device,
    batch_size: int = 1,
    prompt_mode: bool = False,
):
    """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good.

    The PyTorch reference comes from ``WhisperHelper.pt_transcription_for_verify_onnx``. The ONNX
    Runtime session is fed the same input features plus whichever auxiliary inputs it declares.

    Args:
        model_name_or_path: model id or local path used to load the PyTorch reference model.
        cache_dir: cache directory passed to ``from_pretrained``.
        ort_session: ONNX Runtime session for the exported Whisper model.
        device: device the PyTorch model is moved to.
        batch_size: number of batch entries to verify.
        prompt_mode: whether text prompts are prepended to the decoder input ids.

    Returns:
        0 when every PyTorch and ONNX Runtime transcription is one of the accepted expected
        transcriptions. Otherwise, the largest absolute difference between corresponding generated
        token ids. Sequences are compared over the length they share, which is 0 when no token
        differs there.
    """
    extra_kwargs = {}
    if version.parse(transformers_version) >= version.parse("4.36.0"):
        extra_kwargs["attn_implementation"] = "eager"
    pt_model = WhisperForConditionalGeneration.from_pretrained(
        model_name_or_path, cache_dir=cache_dir, **extra_kwargs
    ).to(device)
    processor = WhisperProcessor.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)

    # text_prompt_ids: per-entry text prompt ids (prompt mode only), distinct from the
    # language/task token ids computed below.
    inputs, pt_transcription, pt_outputs, text_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx(
        processor,
        pt_model,
        device,
        batch_size=batch_size,
        prompt_mode=prompt_mode,
    )

    start_id = [config.decoder_start_token_id]  # ex: [50258]
    # get_decoder_prompt_ids yields pairs whose second element is the token id.
    task_token_ids = [
        token[1] for token in processor.get_decoder_prompt_ids(language="english", task="transcribe")
    ]  # ex: [50259, 50358, 50363]
    forced_decoder_ids = start_id + task_token_ids  # ex: [50258, 50259, 50358, 50363]

    ort_input_meta = ort_session.get_inputs()
    ort_names = [entry.name for entry in ort_input_meta]
    ort_dtypes = [entry.type for entry in ort_input_meta]
    ort_to_np = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
        "tensor(int8)": np.int8,
        "tensor(uint8)": np.uint8,
    }

    use_extra_decoding_ids = "extra_decoding_ids" in ort_names
    for name, dtype in zip(ort_names, ort_dtypes):
        if name == "input_features":
            inputs[name] = inputs[name].detach().cpu().numpy()
        elif name == "vocab_mask":
            inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype])
        elif name == "prefix_vocab_mask":
            inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
        elif name == "decoder_input_ids":
            if not prompt_mode:
                # When the model takes extra_decoding_ids, the language/task tokens go there instead.
                raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
                inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
            else:
                # This logic handles the scenario for when prompts are not of the same size
                # For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1]
                # The final decoder_input_ids will look as such after padding
                # [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token]
                # [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]
                ort_prompts = [text_prompt_ids[i].tolist() for i in range(batch_size)]
                max_len = max(len(p) for p in ort_prompts)
                padded_prompts = [
                    [*p, *([config.pad_token_id] * (max_len - len(p)))] + forced_decoder_ids for p in ort_prompts
                ]
                inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype])
        elif name == "logits_processor":
            inputs[name] = np.array([1], dtype=ort_to_np[dtype])
        elif name == "cross_qk_layer_head":
            inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype])
        elif name == "extra_decoding_ids":
            inputs[name] = np.repeat(np.array([task_token_ids], dtype=ort_to_np[dtype]), batch_size, 0)
        elif name == "temperature":
            inputs[name] = np.array([1.0], dtype=ort_to_np[dtype])
        else:
            # Remaining inputs are scalar generation settings carried over from `inputs`.
            inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])

    # Keep the first returned sequence of each batch entry.
    ort_outputs = ort_session.run(None, inputs)[0][:, 0, :]
    ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)
    expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode)

    parity = all(
        pt_transcription[i] in expected_transcription_options
        and ort_transcription[i] in expected_transcription_options
        for i in range(batch_size)
    )

    max_diff = 0
    if not parity:
        for i in range(batch_size):
            # PyTorch outputs are presumably 2-D (1, seq_len) in prompt mode and 1-D otherwise,
            # while each ORT row is 1-D, so compare flattened sequences. Cast to int64 so the
            # subtraction cannot wrap for narrower or unsigned integer dtypes.
            pt_ids = np.asarray(pt_outputs[i]).reshape(-1).astype(np.int64)
            ort_ids = np.asarray(ort_outputs[i]).reshape(-1).astype(np.int64)
            common_len = min(len(pt_ids), len(ort_ids))
            if common_len == 0:
                continue
            diff = pt_ids[:common_len] - ort_ids[:common_len]
            # Use the magnitude: a signed max would hide negative differences behind the initial 0.
            max_diff = max(max_diff, int(np.abs(diff).max()))

    if max_diff != 0:
        logger.warning(f"PyTorch outputs: {pt_transcription}")
        logger.warning(f"ONNX Runtime outputs: {ort_transcription}")

    return max_diff
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class WhisperDecoderInitOpenai(torch.nn.Module):
    """WhisperDecoderInit for Openai.

    Wraps an OpenAI Whisper model and its text decoder so that a decoder step takes and returns
    key/value caches as flat lists of tensors. The caller's ``past`` list is unpacked into the
    ``kv_cache`` mapping the decoder consumes. The present caches, which presumably are collected
    by the model's kv-cache hooks, are returned as a flat list.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        decoder: torch.nn.Module,
        head_size: int = 64,
    ):
        """Store the model/decoder pair.

        Args:
            model: OpenAI Whisper model providing ``install_kv_cache_hooks()``.
            decoder: text decoder exposing ``blocks`` whose ``attn``/``cross_attn`` have
                ``key``/``value`` members used as cache keys.
            head_size: per-head size used to split flattened caches back into heads; defaults
                to 64, the value previously hard-coded.
        """
        super().__init__()
        self.whisper_model = model
        self.whisper_decoder = decoder
        self.head_size = head_size
        # Mapping from attention key/value modules to cached tensors; presumably filled by the
        # hooks returned from install_kv_cache_hooks() during the decoder call.
        self.kv_cache = {}
        # Hook handles from install_kv_cache_hooks(), kept so a later call can remove them.
        self._hooks = None

    @torch.no_grad()
    def forward(
        self,
        tokens,
        audio_features,
        past=None,
        remove_hooks=False,
    ):
        """Run one decoder step and return logits plus present key/value caches.

        Args:
            tokens: decoder token ids, passed through to the decoder.
            audio_features: encoder output, passed through to the decoder.
            past: optional list of 4-D cache tensors. The first half holds the self-attention
                (key, value) pair of each decoder block in order; the second half holds the
                cross-attention (key, value) pairs.
            remove_hooks: when True, remove the tracked kv-cache hooks and reset ``self.kv_cache``
                after this step.

        Returns:
            ``(logits, present)``. ``present`` lists the self-attention (key, value) caches per
            block, followed by the cross-attention ones only when ``past`` is None. Each cache is
            reshaped to 4-D as ``(d0, d1, -1, head_size)`` and then transposed on dims 1 and 2.
        """
        # Create a kv_cache for past_values
        past_kv_cache = {}
        if past is not None:
            # Convert past values from 4D to 3D (inverse of the present conversion below)
            past = [torch.transpose(val, 1, 2) for val in past]
            past = [val.reshape(val.shape[:2] + (-1,)) for val in past]
            half_idx = len(past) // 2
            for idx, block in enumerate(self.whisper_decoder.blocks):
                past_kv_cache[block.attn.key] = past[2 * idx]
                past_kv_cache[block.attn.value] = past[2 * idx + 1]
                past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx]
                past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1]

        if not self.kv_cache:
            self.kv_cache, hooks = self.whisper_model.install_kv_cache_hooks()
            # Remember the handles; a local-only reference could never be removed on later calls.
            self._hooks = hooks

        logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache)

        # Add concat node for past values
        if past is not None:
            for block in self.whisper_decoder.blocks:
                self.kv_cache[block.attn.key] = torch.cat(
                    [past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1
                ).detach()
                self.kv_cache[block.attn.value] = torch.cat(
                    [past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1
                ).detach()

        present_self, present_cross = [], []
        # Group self and cross values
        for block in self.whisper_decoder.blocks:
            present_self.append(self.kv_cache[block.attn.key])
            present_self.append(self.kv_cache[block.attn.value])
            if past is None:
                present_cross.append(self.kv_cache[block.cross_attn.key])
                present_cross.append(self.kv_cache[block.cross_attn.value])

        present_self = present_self + present_cross
        # Add reshape and transpose ops to convert from 3D to 4D
        present_self = [
            present_val.reshape(present_val.shape[:2] + (-1, self.head_size)).transpose(1, 2)
            for present_val in present_self
        ]

        # Remove forward hooks to avoid model cloning step
        if remove_hooks and self._hooks is not None:
            self.kv_cache = {}
            for hook in self._hooks:
                hook.remove()
            self._hooks = None
        return logits, present_self
|