onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/whisper/whisper_jump_times.py
@@ -0,0 +1,477 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import os
+import tempfile
+import textwrap
+from pathlib import Path
+
+import numpy as np
+import onnx
+import torch
+import torch.nn.functional as F
+import torch.utils.cpp_extension
+from onnx_model import OnnxModel
+from transformers import WhisperConfig
+from whisper_inputs import convert_inputs_for_ort, get_model_dynamic_axes, get_sample_jump_times_inputs
+
+from onnxruntime import InferenceSession
+from onnxruntime.tools import pytorch_export_contrib_ops
+
+logger = logging.getLogger(__name__)
+
+##################################################
+# Functions that have to be outside of the class
+# for torch.jit.script_if_tracing to work
+##################################################
+
+
+@torch.jit.script_if_tracing
+def index_QKs(alignment_heads: torch.Tensor, QKs: list[torch.Tensor]):  # noqa: N802
+    """
+    Compute the following to get stacked QK tensor that has been indexed for the desired attention heads:
+        weights = torch.stack([QKs[_l][:, _h] for _l, _h in alignment_heads], dim=1)
+    """
+    indexed_QKs = []  # noqa: N806
+    for pair in alignment_heads:
+        # Each QK is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
+        # The `QKs[_l]` selects the right QK from the list of QKs
+        # The `QKs[_l][:, _h]` selects the right attention heads from the chosen QK. The `:` is to do this for the batch dim.
+        #
+        # PyTorch:
+        # QKs[_l] is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
+        # QKs[_l][:, _h] is of shape (batch_size, sequence_length, num_frames // 2)
+        #
+        # ONNX:
+        # QKs[_l] is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
+        # QKs[_l][:, _h] is of shape (batch_size, 1, sequence_length, num_frames // 2) because
+        # the `[:, _h]` operation maps to a Gather op and that op does not reduce dimensions
+        _l, _h = pair[0], pair[1]
+        indexed_QKs.append(QKs[_l][:, _h])
+
+    # PyTorch:
+    # torch.stack will return a tensor of shape (batch_size, num_alignment_heads, sequence_length, num_frames // 2).
+    #
+    # ONNX:
+    # torch.stack will return a tensor of shape (batch_size, num_alignment_heads, 1, sequence_length, num_frames // 2)
+    # because the Gather op does not reduce dimensions. To remove the unneeded dimension, torch.squeeze with a specified
+    # dim (dim = 2) is added. The torch.squeeze op with a specified dim only runs if the specified dim has a size of 1.
+    # Since the dim won't be of size 1 in the PyTorch tensor but it is of size 1 in the ONNX tensor, it will be a no-op
+    # in PyTorch and an op in ONNX. Thus, the Squeeze op will only affect the ONNX model.
+    weights = torch.stack(indexed_QKs, dim=1)
+    weights = torch.squeeze(weights, dim=2)
+    return weights
+
+
+def jump_timings(text_indices, time_indices):
+    """
+    Calculate jump times from text_indices and time_indices where
+    text_indices and time_indices are both 1d vectors
+    """
+    TOKENS_PER_SECOND = 50.0  # noqa: N806
+    diff = text_indices[1:] - text_indices[:-1]
+    padding = torch.tensor([1], dtype=torch.int32)
+    jumps = torch.cat((padding, diff)).to(torch.bool)
+    jump_times = time_indices[jumps].to(torch.float) / TOKENS_PER_SECOND
+    return jump_times
+
+
+def padded_jump_from_dtw(matrix_2d: torch.Tensor, max_length: torch.Tensor):
+    """
+    Run Dynamic Time Warping (DTW) on batched tensor
+    """
+    trace = torch.ops.onnxruntime.DynamicTimeWarping(matrix_2d)
+    text_indices = trace[0, :]
+    time_indices = trace[1, :]
+    jump_times = jump_timings(text_indices, time_indices)
+    return F.pad(jump_times, [0, int((max_length - jump_times.size(-1)).item())], mode="constant", value=-1.0)
+
+
+@torch.jit.script_if_tracing
+def batch_jump_times(matrix: torch.Tensor, max_decoded_length: torch.Tensor):
+    """
+    Compute the following to calculate jump times for all batches:
+        batched_jump_times = torch.stack([self.padded_jump_from_dtw(matrix[b], max_decoded_length) for b in range(matrix.size(0))])
+    """
+    list_of_jump_times = []
+    for b in range(matrix.size(0)):
+        jump_times = padded_jump_from_dtw(matrix[b], max_decoded_length)
+        list_of_jump_times.append(jump_times)
+    batched_jump_times = torch.stack(list_of_jump_times)
+    return batched_jump_times
+
+
+class WhisperJumpTimes(torch.nn.Module):
+    """Whisper jump times component"""
+
+    def __init__(self, config: WhisperConfig, device: torch.device, cache_dir: str | os.PathLike):
+        super().__init__()
+        self.config = config
+        self.device = device
+        self.cache_dir = cache_dir
+
+        self.filter_width = 7
+        self.qk_scale = 1.0
+
+    def median_filter(self, weights: torch.Tensor):
+        """
+        Apply a median filter of width `filter_width` along the last dimension of `weights`
+        """
+        pad_width = self.filter_width // 2
+        x = F.pad(weights, (pad_width, pad_width, 0, 0), mode="reflect")
+        x_unfolded = torch.ops.onnxruntime.UnfoldTensor(x, -1, self.filter_width, 1)
+        result = torch.select(x_unfolded.sort()[0], dim=-1, index=pad_width)
+        return result
+
+    def forward(
+        self,
+        alignment_heads: torch.Tensor,
+        sot_sequence_length: torch.Tensor,
+        segment_length: torch.Tensor,
+        QKs: list[torch.Tensor],
+    ):
+        # Get stacked QKs tensor
+        weights = index_QKs(alignment_heads, QKs)
+        weights = weights[:, :, : segment_length // 2]
+        weights = weights.to(torch.float32)
+
+        weights = (weights * self.qk_scale).softmax(dim=-1)
+        std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
+        weights = (weights - mean) / std
+        weights = self.median_filter(weights)
+
+        matrix = torch.mean(weights, 1)
+        matrix = -matrix[:, sot_sequence_length:-1]
+
+        max_decoded_length = torch.tensor([matrix.size(1)], dtype=torch.int64)
+        batched_jump_times = batch_jump_times(matrix, max_decoded_length)
+        return batched_jump_times
+
+    def input_names(self):
+        input_names = [
+            "alignment_heads",
+            "sot_sequence_length",
+            "segment_length",
+            *[f"cross_qk_{i}" for i in range(self.config.decoder_layers)],
+        ]
+        return input_names
+
+    def output_names(self):
+        output_names = ["jump_times"]
+        return output_names
+
+    def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
+        inputs = get_sample_jump_times_inputs(
+            self.config,
+            self.device,
+            batch_size=2,
+            sequence_length=8,
+            num_alignment_heads=6,
+            sot_sequence_length=3,
+            segment_length=1332,
+            use_fp16=use_fp16_inputs,
+            use_int32=use_int32_inputs,
+        )
+        if return_dict:
+            return inputs
+        return (
+            inputs["alignment_heads"],
+            inputs["sot_sequence_length"],
+            inputs["segment_length"],
+            inputs["QKs"],
+        )
+
+    def create_torch_ops(self):
"""
|
|
190
|
+
1) Create UnfoldTensor and DynamicTimeWarping as torch ops
|
|
191
|
+
3) Provide a symbolic mapping from torch ops to ORT contrib ops
|
|
192
|
+
|
|
193
|
+
See https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html#building-with-jit-compilation
|
|
194
|
+
for more details on how this works.
|
|
195
|
+
"""
|
|
196
|
+
# Set torch extensions directory to cache directory
|
|
197
|
+
os.environ["TORCH_EXTENSIONS_DIR"] = self.cache_dir
|
|
198
|
+
|
|
199
|
+
# Try to import `ninja` pip package
|
|
200
|
+
try:
|
|
201
|
+
assert torch.utils.cpp_extension.verify_ninja_availability()
|
|
202
|
+
except Exception as e:
|
|
203
|
+
logger.error(f"An error occurred while verifying `ninja` is available: {e}", exc_info=True) # noqa: G201
|
|
204
|
+
install_cmd = "pip install ninja"
|
|
205
|
+
logger.warning(f"Could not import `ninja`. Attempting to install `ninja` via `{install_cmd}`.")
|
|
206
|
+
os.system(install_cmd)
+
+        # Create UnfoldTensor torch op
+        unfold_op_source = textwrap.dedent("""\
+            #include "torch/script.h"
+
+            torch::Tensor UnfoldTensor(torch::Tensor input, int64_t dim, int64_t size, int64_t step) {
+                return input.unfold(dim, size, step);
+            }
+
+            // namespace is onnxruntime
+            static auto registry = torch::RegisterOperators("onnxruntime::UnfoldTensor", &UnfoldTensor);
+        """)
+
+        torch.utils.cpp_extension.load_inline(
+            name="UnfoldTensor",
+            cpp_sources=unfold_op_source,
+            is_python_module=False,
+            verbose=True,
+        )
+
+        # Create DynamicTimeWarping torch op
+        dtw_op_source = textwrap.dedent("""\
+            #include "torch/script.h"
+            #include "torch/torch.h"
+            #include <stdexcept>
+            #include <tuple>
+            #include <vector>
+
+            torch::Tensor Backtrace(torch::Tensor trace) {
+                int64_t i = trace.size(0) - 1;
+                int64_t j = trace.size(1) - 1;
+                trace.index({0, torch::indexing::Slice()}) = 2;
+                trace.index({torch::indexing::Slice(), 0}) = 1;
+
+                std::vector<int32_t> result_vec;
+                while (i > 0 || j > 0) {
+                    result_vec.push_back(static_cast<int32_t>(i - 1));
+                    result_vec.push_back(static_cast<int32_t>(j - 1));
+                    int value = trace[i][j].item<int>();
+
+                    if (value == 0) {
+                        i--;
+                        j--;
+                    } else if (value == 1) {
+                        i--;
+                    } else if (value == 2) {
+                        j--;
+                    } else {
+                        throw std::runtime_error("Unexpected trace[i, j]");
+                    }
+                }
+
+                // Compute result[::-1, :].T
+                torch::Tensor result = torch::from_blob(result_vec.data(), {static_cast<long int>(result_vec.size() / 2), 2}, torch::kInt32).clone();
+                torch::Tensor reversed = result.flip(0);  // result[::-1, :]
+                torch::Tensor transposed = reversed.transpose(0, 1);  // .T
+                return transposed;
+            }
+
+            torch::Tensor DynamicTimeWarping(torch::Tensor x) {
+                int64_t N = x.size(0);
+                int64_t M = x.size(1);
+                torch::Tensor cost = torch::full({N + 1, M + 1}, std::numeric_limits<float>::infinity(), torch::dtype(torch::kFloat32));
+                torch::Tensor trace = torch::full({N + 1, M + 1}, -1, torch::dtype(torch::kFloat32));
+
+                cost[0][0] = 0;
+                for (int j = 1; j < M + 1; j++) {
+                    for (int i = 1; i < N + 1; i++) {
+                        float c0 = cost[i - 1][j - 1].item<float>();
+                        float c1 = cost[i - 1][j].item<float>();
+                        float c2 = cost[i][j - 1].item<float>();
+
+                        float c = 0;
+                        float t = 0;
+
+                        if (c0 < c1 && c0 < c2) {
+                            c = c0;
+                            t = 0;
+                        } else if (c1 < c0 && c1 < c2) {
+                            c = c1;
+                            t = 1;
+                        } else {
+                            c = c2;
+                            t = 2;
+                        }
+
+                        cost[i][j] = x[i - 1][j - 1].item<float>() + c;
+                        trace[i][j] = t;
+                    }
+                }
+
+                return Backtrace(trace);
+            }
+
+            // namespace is onnxruntime
+            static auto registry = torch::RegisterOperators("onnxruntime::DynamicTimeWarping", &DynamicTimeWarping);
+        """)
+
+        torch.utils.cpp_extension.load_inline(
+            name="DynamicTimeWarping",
+            cpp_sources=dtw_op_source,
+            is_python_module=False,
+            verbose=True,
+        )
+
+        # Create symbolic mapping from torch ops to ORT contrib ops
+        pytorch_export_contrib_ops.register()
+
+    def export_onnx(
+        self,
+        onnx_model_path: str,
+        provider: str,
+        verbose: bool = True,
+        use_external_data_format: bool = False,
+        use_fp16_inputs: bool = False,
+        use_int32_inputs: bool = True,
+    ):
+        """Export word-level timestamps to ONNX
+
+        Args:
+            onnx_model_path (str): path to save ONNX model
+            provider (str): provider to use for verifying parity on ONNX model
+            verbose (bool, optional): print verbose information. Defaults to True.
+            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
+            use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
+            use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
+        """
+        # Shape of timestamps's tensors:
+        # Inputs:
+        #    alignment_heads: (num_alignment_heads, 2)
+        #    sot_sequence_length: (1)
+        #    segment_length: (1)
+        #    cross_qk_*: (batch_size, num_heads, sequence_length, num_frames // 2)
+        # Outputs:
+        #    jump_times: (batch_size, max_length)
+
+        # Definitions:
+        # alignment_heads: the attention head indices where the Q*K values are highly correlated with word-level timestamps
+        # (i.e. the alignment between audio and text tokens)
+        # This is calculated as follows:
+        #
+        # ```
+        # import base64
+        # import gzip
+        # import numpy as np
+        # import torch
+        #
+        # # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
+        # # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
+        # _ALIGNMENT_HEADS = {
+        #     "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
+        #     "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
+        #     "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
+        #     "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
+        #     "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
+        #     "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
+        #     "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
+        #     "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
+        #     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
+        #     "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
+        #     "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+        #     "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+        #     "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
+        #     "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
+        # }
+        #
+        # model_name = "large-v3-turbo"
+        # array = np.frombuffer(
+        #     gzip.decompress(base64.b85decode(_ALIGNMENT_HEADS[model_name])), dtype=bool
+        # ).copy()
+        # mask = torch.from_numpy(array).reshape(
+        #     self.dims.n_text_layer, self.dims.n_text_head
+        # )
+        # self.alignment_heads = mask.to_sparse().indices().T
+        # ```
+        #
+        # sot_sequence_length: the length of the start-of-transcription sequence before the first token is generated
+        # Typically the start-of-transcription sequence is [<|startoftranscription|>, <|language_token|>, <|task_token|>]
+        # so its length is 3.
+        #
+        # segment_length: the length (in frames) of the audio segment that is being transcribed
+        #
+        # cross_qk_*: the Q*K values for the cross-attention blocks in the decoder
+        # Every decoder layer has a self-attention block and a cross-attention block so there are `n` cross-attention blocks
+        # where `n` is the number of decoder layers.
+        #
+        # jump_times: the timings where jumps occur in speech
+        # This allows us to detect when a word began to be spoken by the speaker (start_times) and when a word was finished
+        # being spoken by the speaker (end_times).
+
+        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
+        input_names = self.input_names()
+        output_names = self.output_names()
+        dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
+
+        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+        with tempfile.TemporaryDirectory() as tmp_dir_name:
+            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
+            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+            out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
+
+            # Create torch ops and map them to ORT contrib ops before export
+            self.create_torch_ops()
+            torch.onnx.export(
+                self,
+                args=inputs,
+                f=out_path,
+                export_params=True,
+                input_names=input_names,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+                opset_version=17,
+                do_constant_folding=True,
+                verbose=verbose,
+                custom_opsets={"com.microsoft": 1},
+            )
+
+            if use_external_data_format:
+                model = onnx.load_model(out_path, load_external_data=use_external_data_format)
+                OnnxModel.save(
+                    model,
+                    onnx_model_path,
+                    save_as_external_data=True,
+                    all_tensors_to_one_file=True,
+                )
+
+        self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
+
+    def verify_onnx(
+        self,
+        onnx_model_path: str,
+        provider: str,
+        use_fp16_inputs: bool,
+        use_int32_inputs: bool,
+    ):
+        """Verify ONNX model outputs and PyTorch model outputs match
+
+        Args:
+            onnx_model_path (str): path to save ONNX model
+            provider (str): execution provider for ONNX model
+            use_fp16_inputs (bool, optional): use float16 inputs for the cross_qk_{i}
+            use_int32_inputs (bool, optional): use int32 inputs for the alignment_heads and sot_sequence_length
+        """
+        # Shape of jump times's tensors:
+        # Inputs:
+        #    alignment_heads: (num_alignment_heads, 2)
+        #    sot_sequence_length: (1)
+        #    segment_length: (1)
+        #    cross_qk_*: (batch_size, num_heads, sequence_length, num_frames // 2)
+        # Outputs:
+        #    jump_times: (batch_size, max_length)
+        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
+
+        # Run PyTorch model
+        pt_outputs = (
+            self.forward(
+                inputs["alignment_heads"], inputs["sot_sequence_length"], inputs["segment_length"], inputs["QKs"]
+            )
+            .detach()
+            .cpu()
+            .numpy()
+        )
+
+        # Run ONNX model
+        sess = InferenceSession(onnx_model_path, providers=[provider])
+        ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))
+
+        # Calculate output difference
+        diff = np.abs(pt_outputs - ort_outputs)
+        print("Comparing batched jump_times...", flush=True)
+        print(f"Max diff: {np.max(diff)}", flush=True)