onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/whisper/whisper_decoder.py @@ -0,0 +1,464 @@

```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import os
import tempfile
from itertools import chain
from pathlib import Path

import numpy as np
import onnx
import torch
from float16 import convert_float_to_float16
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from onnx import ModelProto, ValueInfoProto
from onnx_model import OnnxModel
from past_helper import PastKeyValuesHelper
from transformers import WhisperConfig
from whisper_inputs import (
    convert_inputs_for_ort,
    get_model_dynamic_axes,
    get_sample_decoder_inputs,
    group_past_key_values,
)

from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)


class WhisperDecoder(torch.nn.Module):
    """A Whisper decoder with optional past key values"""

    def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str, no_beam_search_op: bool = False):
        super().__init__()
        self.config = config
        self.device = model.device
        self.model_impl = model_impl
        self.no_beam_search_op = no_beam_search_op

        self.decoder = None if model_impl == "openai" else model.model.decoder
        self.proj_out = None if model_impl == "openai" else model.proj_out
        self.model = model if model_impl == "openai" else None

        self.max_source_positions = self.config.max_source_positions
        self.num_heads = self.config.decoder_attention_heads
        self.head_size = self.config.d_model // self.num_heads

    def hf_forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        past_key_values: list[tuple[torch.Tensor]] | None = None,
    ):
        outputs = self.decoder(
            encoder_hidden_states=encoder_hidden_states,
            input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        logits = self.proj_out(outputs.last_hidden_state)
        present_key_values = outputs.past_key_values

        if past_key_values is None:
            # Return present_self_* and present_cross_* for decoder-init
            return logits, present_key_values

        # Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
        #         (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
        # After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1), ...,
        #        (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1), ...
        present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(present_key_values)

        # Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
        return logits, present_self

    def oai_forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        past_key_values: list[tuple[torch.Tensor]] | None = None,
    ):
        past_kv_cache = {}
        if past_key_values is not None:
            # Convert past KV caches (BxNxSxH --> BxSxNxH --> BxSxD) for OpenAI's forward pass
            self_attn_kv_caches, cross_attn_kv_caches = group_past_key_values(past_key_values)
            self_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in self_attn_kv_caches]
            self_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in self_attn_kv_caches]
            cross_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in cross_attn_kv_caches]
            cross_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in cross_attn_kv_caches]

            for idx, block in enumerate(self.model.decoder.blocks):
                past_kv_cache[block.attn.key] = self_attn_kv_caches[2 * idx]
                past_kv_cache[block.attn.value] = self_attn_kv_caches[2 * idx + 1]
                past_kv_cache[block.cross_attn.key] = cross_attn_kv_caches[2 * idx]
                past_kv_cache[block.cross_attn.value] = cross_attn_kv_caches[2 * idx + 1]

        # Install OpenAI's hooks on the forward pass of each nn.Linear for key and value
        # since the hooks will capture the output of the key and value MatMuls, which
        # represent the current keys and values.
        #
        # For OpenAI's forward pass, the hook function will also perform the concat
        # operation (past_kv + curr_kv --> pres_kv) if needed. However, the ONNX model
        # will not contain this concat operation because the present KV caches aren't
        # returned by OpenAI's forward pass.
        kv_cache, hooks = self.model.install_kv_cache_hooks()

        # Run forward pass
        # NOTE: There is a bug with openai-whisper==20240930 with the introduction of SDPA.
        # In the Whisper codebase, the following line
        #
        # is_causal = mask is not None and n_ctx > 1
        #
        # has been added where `mask` is a torch tensor. The right-hand side evaluates to `tensor(True/False)`
        # but `is_causal` only accepts the boolean value. The fix is to apply `.item()` after the right-hand
        # side has been evaluated. In other words, the line should be
        #
        # is_causal = (mask is not None and n_ctx > 1).item()
        #
        # instead.
        logits = self.model.decoder(x=decoder_input_ids, xa=encoder_hidden_states, kv_cache=past_kv_cache)

        # Re-do concat operation on self attention KV caches for ONNX export (if past self attention KV caches exist)
        if past_key_values is not None:
            for block in self.model.decoder.blocks:
                kv_cache[block.attn.key] = torch.cat(
                    [past_kv_cache[block.attn.key], kv_cache[block.attn.key]], dim=1
                ).detach()
                kv_cache[block.attn.value] = torch.cat(
                    [past_kv_cache[block.attn.value], kv_cache[block.attn.value]], dim=1
                ).detach()

        present_self, present_cross = [], []
        for block in self.model.decoder.blocks:
            # Group self and cross values
            present_self.append(kv_cache[block.attn.key])
            present_self.append(kv_cache[block.attn.value])
            if past_key_values is None:
                # Return present_self_* and present_cross_* for decoder-init
                present_cross.append(kv_cache[block.cross_attn.key])
                present_cross.append(kv_cache[block.cross_attn.value])

        # Convert present KV caches (BxSxD --> BxSxNxH --> BxNxSxH) after OpenAI's forward pass
        present_self = [
            present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
            for present_kv in present_self
        ]
        present_cross = [
            present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
            for present_kv in present_cross
        ]

        # Remove OpenAI's hooks since they can persist after this function completes
        for hook in hooks:
            hook.remove()

        if past_key_values is None:
            # Return present_self_* and present_cross_* for decoder-init
            present_key_values = PastKeyValuesHelper.group_by_layer(
                present_self + present_cross, len(present_self) // 2
            )
            return logits, present_key_values

        # Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
        return logits, present_self

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        past_key_values: list[tuple[torch.Tensor]] | None = None,
    ):
        if self.model_impl == "openai":
            return self.oai_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
        return self.hf_forward(decoder_input_ids, encoder_hidden_states, past_key_values)

    def input_names(self):
        if self.first_pass:
            input_names = ["input_ids", "encoder_hidden_states"]
        else:
            input_names = [
                "input_ids",
                "encoder_hidden_states",
                *list(
                    chain.from_iterable(
                        (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}")
                        for i in range(self.config.decoder_layers)
                    )
                ),
            ]
        return input_names

    def output_names(self):
        if self.first_pass:
            output_names = [
                "logits",
                *list(
                    chain.from_iterable(
                        (
                            f"present_key_self_{i}",
                            f"present_value_self_{i}",
                            f"present_key_cross_{i}",
                            f"present_value_cross_{i}",
                        )
                        for i in range(self.config.decoder_layers)
                    )
                ),
            ]
        else:
            output_names = [
                "logits",
                *list(
                    chain.from_iterable(
                        (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers)
                    )
                ),
            ]
        return output_names

    def dynamic_axes(self, input_names, output_names):
        dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
        if "input_ids" in dynamic_axes and not self.no_beam_search_op:
            # Set dynamic axes for `input_ids` when using beam search op to {0: "batch_size"} only
            del dynamic_axes["input_ids"][1]
        return dynamic_axes

    def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
        inputs = get_sample_decoder_inputs(
            self.config,
            self.device,
            batch_size=2,
            past_sequence_length=(0 if self.first_pass else 6),
            sequence_length=(6 if self.first_pass else 1),
            use_fp16=use_fp16_inputs,
            use_int32=use_int32_inputs,
        )
        if return_dict:
            if self.first_pass:
                del inputs["past_key_values"]
            return inputs

        if self.first_pass:
            return (
                inputs["decoder_input_ids"],
                inputs["encoder_hidden_states"],
            )
        return (
            inputs["decoder_input_ids"],
            inputs["encoder_hidden_states"],
            inputs["past_key_values"],
        )

    def fix_key_value_cache_dims(self, io: ValueInfoProto, is_cross: bool = False, is_output: bool = False):
        # Shape should be (batch_size, num_heads, sequence_length, head_size) for self attention KV caches
        # and (batch_size, num_heads, num_frames // 2, head_size) for cross attention KV caches
        num_heads = io.type.tensor_type.shape.dim[1]
        if "_dim_" in num_heads.dim_param:
            num_heads.Clear()
            num_heads.dim_value = self.num_heads
        sequence_length = io.type.tensor_type.shape.dim[2]
        if "_dim_" in sequence_length.dim_param:
            sequence_length.Clear()
            if is_cross:
                sequence_length.dim_value = self.max_source_positions
            else:
                sequence_length.dim_param = "total_sequence_length" if is_output else "past_sequence_length"
        head_size = io.type.tensor_type.shape.dim[3]
        if "_dim_" in head_size.dim_param:
            head_size.Clear()
            head_size.dim_value = self.head_size
        return io

    def fix_io(self, io_list: RepeatedCompositeFieldContainer, is_output: bool = False):
        # Fix order of inputs/outputs and each dim_value of input/output
        reordered_io = []
        self_attn_kv_caches = []
        cross_attn_kv_caches = []

        for io in io_list:
            if "past" not in io.name and "present" not in io.name:
                reordered_io.append(io)
            elif "self" in io.name:
                # Self attention KV caches
                new_io = self.fix_key_value_cache_dims(io, is_cross=False, is_output=is_output)
                if self.no_beam_search_op:
                    reordered_io.append(new_io)
                else:
                    self_attn_kv_caches.append(new_io)
            else:
                # Cross attention KV caches
                new_io = self.fix_key_value_cache_dims(io, is_cross=True, is_output=is_output)
                if self.no_beam_search_op:
                    reordered_io.append(new_io)
                else:
                    cross_attn_kv_caches.append(new_io)

        if not self.no_beam_search_op:
            reordered_io += self_attn_kv_caches + cross_attn_kv_caches
        return reordered_io

    def fix_inputs_and_outputs(self, model: ModelProto):
        # ONNX exporter might mark dimensions like 'Transposepresent_value_self_1_dim_2' in shape inference.
        # We now change the dim_values to the correct one.
        reordered_inputs = self.fix_io(model.graph.input, is_output=False)
        while len(model.graph.input) > 0:
            model.graph.input.pop()
        model.graph.input.extend(reordered_inputs)

        reordered_outputs = self.fix_io(model.graph.output, is_output=True)
        while len(model.graph.output) > 0:
            model.graph.output.pop()
        model.graph.output.extend(reordered_outputs)
        return model

    def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
        if self.model_impl == "openai" and use_fp16_inputs:
            # Cast ONNX model to float16 to ensure LayerNorm weights are converted from
            # float32 to float16 since exported model already has float16 weights everywhere
            # except for LayerNorm ops. This happens because OpenAI always upcasts to float32
            # when computing LayerNorm.
            #
            # Reference:
            # https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
            model = convert_float_to_float16(model)
        return model

    def export_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_fp16_inputs: bool = False,
        use_int32_inputs: bool = True,
        use_encoder_hidden_states: bool = False,
        use_kv_cache_inputs: bool = True,
    ):
        """Export decoder to ONNX

        Args:
            onnx_model_path (str): path to save ONNX model
            provider (str): provider to use for verifying parity on ONNX model
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_fp16_inputs (bool, optional): use float16 inputs for the KV caches. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
            use_encoder_hidden_states (bool, optional): use encoder_hidden_states as model input for decoder-init/decoder-without-past models. Defaults to False.
            use_kv_cache_inputs (bool, optional): use KV caches as model inputs for decoder-with-past models. Defaults to True.
        """
        # Shape of decoder's tensors:
        # Required Inputs:
        #     decoder_input_ids: (batch_size, sequence_length)
        # Optional Inputs:
        #     encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
        #     past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
        #     past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
        # Outputs:
        #     logits: (batch_size, sequence_length, vocab_size)
        #     present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
        #     present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)

        # For the first pass through the decoder (i.e. decoder-init/decoder-without-past)
        self.first_pass = use_encoder_hidden_states and not use_kv_cache_inputs

        # For subsequent passes through the decoder (i.e. decoder-with-past)
        self.later_pass = not use_encoder_hidden_states and use_kv_cache_inputs

        assert self.first_pass or self.later_pass, (
            "Only one of `use_encoder_hidden_states` and `use_kv_cache_inputs` can be true at once."
        )

        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
        input_names = self.input_names()
        output_names = self.output_names()
        dynamic_axes = self.dynamic_axes(input_names, output_names)

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path

            torch.onnx.export(
                self,
                args=inputs,
                f=out_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=17,
                do_constant_folding=True,
                verbose=verbose,
            )

            model = onnx.load_model(out_path, load_external_data=use_external_data_format)
            model = self.fix_inputs_and_outputs(model)
            model = self.fix_layernorm_weights(model, use_fp16_inputs)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=use_external_data_format,
                all_tensors_to_one_file=True,
            )

        self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)

    def verify_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        use_fp16_inputs: bool,
        use_int32_inputs: bool,
    ):
        """Verify ONNX model outputs and PyTorch model outputs match

        Args:
            onnx_model_path (str): path to save ONNX model
            provider (str): execution provider for ONNX model
            use_fp16_inputs (bool, optional): use float16 inputs for the KV caches
            use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids
        """
        # Shape of decoder's tensors:
        # Required Inputs:
        #     decoder_input_ids: (batch_size, sequence_length)
        # Optional Inputs:
        #     encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
        #     past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
        #     past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
        # Outputs:
        #     logits: (batch_size, sequence_length, vocab_size)
        #     present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
        #     present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)

        # Run PyTorch model
        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
        pt_outputs = []
        if self.first_pass:
            out = self.forward(**inputs)
            pt_outputs.append(out[0].detach().cpu().numpy())
            for present_key_value_layer in out[1]:
                for present_key_value in present_key_value_layer:
                    pt_outputs.append(present_key_value.detach().cpu().numpy())
        else:
            out = self.forward(**inputs)
            pt_outputs.append(out[0].detach().cpu().numpy())
            for present_self_key_value in out[1]:
                pt_outputs.append(present_self_key_value.detach().cpu().numpy())

        # Run ONNX model
        sess = InferenceSession(onnx_model_path, providers=[provider])
        ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))

        # Calculate output difference
        try:
            for i, output_name in enumerate(self.output_names()):
                diff = np.abs(pt_outputs[i] - ort_outputs[i])
                logger.warning(f"Comparing {output_name}...")
                logger.warning(f"Max diff: {np.max(diff)}")
        except:  # noqa: E722
            pass
```
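The decoder above is exported twice: once as a decoder-init model that consumes `encoder_hidden_states` and emits both self and cross attention KV caches, and once as a decoder-with-past model that consumes the self attention caches it previously produced. A minimal driver sketch is shown below, assuming a Hugging Face Whisper checkpoint (the non-"openai" `model_impl` path); the checkpoint name and output paths are illustrative, and the package's real entry point for this flow is `onnxruntime/transformers/models/whisper/convert_to_onnx.py` listed above.

```python
# Hypothetical driver for WhisperDecoder.export_onnx -- a sketch, not the package's CLI.
from transformers import WhisperConfig, WhisperForConditionalGeneration
from whisper_decoder import WhisperDecoder  # flat import, matching this module's own style

model_name = "openai/whisper-tiny"  # assumption: any HF Whisper checkpoint works the same way
config = WhisperConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

decoder = WhisperDecoder(config, model, model_impl="hf", no_beam_search_op=True)

# Decoder-init: inputs are (decoder_input_ids, encoder_hidden_states); outputs are
# logits plus present self/cross KV caches of shape (batch, num_heads, seq, head_size).
decoder.export_onnx(
    "whisper_decoder_init.onnx",
    provider="CPUExecutionProvider",
    verbose=False,
    use_encoder_hidden_states=True,
    use_kv_cache_inputs=False,
)

# Decoder-with-past: past self/cross KV caches come in as inputs; only the updated
# self attention caches come back out, since the cross attention caches never change.
decoder.export_onnx(
    "whisper_decoder_with_past.onnx",
    provider="CPUExecutionProvider",
    verbose=False,
    use_encoder_hidden_states=False,
    use_kv_cache_inputs=True,
)
```

Each call ends with `verify_onnx`, which re-runs the sample inputs through both PyTorch and ONNX Runtime and logs the max absolute difference per output.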
onnxruntime/transformers/models/whisper/whisper_encoder.py @@ -0,0 +1,164 @@

```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import os
import tempfile
from pathlib import Path

import numpy as np
import onnx
import torch
from float16 import convert_float_to_float16
from onnx import ModelProto
from onnx_model import OnnxModel
from transformers import WhisperConfig
from whisper_inputs import get_model_dynamic_axes, get_sample_encoder_inputs

from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)


class WhisperEncoder(torch.nn.Module):
    """Whisper encoder component"""

    def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str):
        super().__init__()
        self.config = config
        self.device = model.device
        self.model_impl = model_impl

        self.encoder = model.encoder if model_impl == "openai" else model.model.encoder

    def forward(self, audio_features: torch.Tensor):
        outputs = self.encoder(audio_features)
        return outputs if self.model_impl == "openai" else outputs.last_hidden_state

    def input_names(self):
        input_names = ["audio_features"]
        return input_names

    def output_names(self):
        output_names = ["encoder_hidden_states"]
        return output_names

    def dynamic_axes(self, input_names, output_names):
        dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
        return dynamic_axes

    def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
        if self.model_impl == "openai" and use_fp16_inputs:
            # Cast ONNX model to float16 to ensure LayerNorm weights are converted from
            # float32 to float16 since exported model already has float16 weights everywhere
            # except for LayerNorm ops. This happens because OpenAI always upcasts to float32
            # when computing LayerNorm.
            #
            # Reference:
            # https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
            model = convert_float_to_float16(model)
        return model

    def export_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_fp16_inputs: bool = False,
    ):
        """Export encoder to ONNX

        Args:
            onnx_model_path (str): path to save ONNX model
            provider (str): provider to use for verifying parity on ONNX model
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
        """
        # Shape of encoder's tensors:
        # Inputs:
        #     audio_features: (batch_size, num_mels, num_frames)
        # Outputs:
        #     encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)

        inputs = get_sample_encoder_inputs(
            self.config,
            self.device,
            batch_size=2,
            use_fp16=use_fp16_inputs,
        )

        input_names = self.input_names()
        output_names = self.output_names()
        dynamic_axes = self.dynamic_axes(input_names, output_names)

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path

            torch.onnx.export(
                self,
                args=(inputs["audio_features"]),
                f=out_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=17,
                do_constant_folding=True,
                verbose=verbose,
            )

            model = onnx.load_model(out_path, load_external_data=use_external_data_format)
            model = self.fix_layernorm_weights(model, use_fp16_inputs)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=use_external_data_format,
                all_tensors_to_one_file=True,
            )

        self.verify_onnx(onnx_model_path, provider, use_fp16_inputs)

    def verify_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        use_fp16_inputs: bool,
    ):
        """Verify ONNX model outputs and PyTorch model outputs match

        Args:
            onnx_model_path (str): path to save ONNX model
            provider (str): execution provider for ONNX model
            use_fp16_inputs (bool, optional): use float16 inputs for the audio_features
        """
        # Shape of encoder's tensors:
        # Inputs:
        #     audio_features: (batch_size, num_mels, num_frames)
        # Outputs:
        #     encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
        inputs = get_sample_encoder_inputs(
            self.config,
            self.device,
            batch_size=2,
            use_fp16=use_fp16_inputs,
        )

        # Run PyTorch model
        pt_outputs = self.forward(inputs["audio_features"]).detach().cpu().numpy()

        # Run ONNX model
        sess = InferenceSession(onnx_model_path, providers=[provider])
        ort_outputs = sess.run(None, {"audio_features": inputs["audio_features"].detach().cpu().numpy()})[0]

        # Calculate output difference
        diff = np.abs(pt_outputs - ort_outputs)
        logger.warning("Comparing encoder_hidden_states...")
        logger.warning(f"Max diff: {np.max(diff)}")
```