onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/longformer/convert_to_onnx.py
@@ -0,0 +1,413 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+# This script converts a Longformer model from HuggingFace transformers 4.0 or later to ONNX.
+# It translates LongformerSelfAttention to the LongformerAttention operator in ONNX Runtime.
+#
+# Before running this script, prepare a python environment in Linux with PyTorch 1.9.0 and other packages installed.
+# Then run "python setup.py install" in the ./torch_extensions directory. If your python version is not 3.8, you will
+# need to update this script with the correct name of longformer_attention.cpython-*.so (search TODO below).
+#
+# It is tested in Ubuntu 18.04 with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.9.0, transformers 4.18.0.
+# Warning: exporting with PyTorch 1.10 or newer might encounter issues, but those versions are fine for benchmarking.
+#
+# Example commands to export the longformer base model in Linux:
+# conda create -n longformer python=3.8
+# conda activate longformer
+# python3 -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
+# python3 -m pip install coloredlogs flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0
+# python3 -m pip install -i https://test.pypi.org/simple/ ort-nightly-gpu
+# cd ./torch_extensions
+# rm -rf build
+# python setup.py install
+# cd ..
+# python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx
+# python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx --no_merge_qkv
+#
+# A GPU is not needed for this script; you can run it on CPU. For --optimize_onnx, you can use either the onnxruntime or onnxruntime-gpu package.
+#
+# For inference of the onnx model, you will need onnxruntime-gpu 1.7.0 or a newer version.
+
+import argparse
+import inspect
+from pathlib import Path
+
+import torch
+import transformers
+from longformer_helper import PRETRAINED_LONGFORMER_MODELS
+from onnx import load_model
+from onnx_model_bert import BertOnnxModel
+from packaging import version
+from torch.onnx import register_custom_op_symbolic
+from torch.onnx.symbolic_helper import parse_args
+from torch_onnx_export_helper import torch_onnx_export
+from transformers import LongformerModel, LongformerSelfAttention
+
+# Supports format 0 or 1
+weight_bias_format = 0
+
+
+@parse_args("v", "v", "v", "v", "v", "v", "v", "i", "i")
+def my_longformer_attention(
+    g,
+    input,
+    weight,
+    bias,
+    mask,
+    global_weight,
+    global_bias,
+    global_mask,
+    num_heads,
+    window,
+):
+    return g.op(
+        "com.microsoft::LongformerAttention",
+        input,
+        weight,
+        bias,
+        mask,
+        global_weight,
+        global_bias,
+        global_mask,
+        num_heads_i=num_heads,
+        window_i=window,
+    )
+
+
+# namespace is onnxruntime, which is registered in longformer_attention.cpp
+register_custom_op_symbolic("onnxruntime::LongformerAttention", my_longformer_attention, 9)
+
+# TODO: search the directory to find the correct output filename of "python setup.py install" when the python version is not 3.8
+torch.ops.load_library(
+    r"./torch_extensions/build/lib.linux-x86_64-3.8/longformer_attention.cpython-38-x86_64-linux-gnu.so"
+)
+
+
+def parse_arguments():
+    """Parse arguments
+
+    Returns:
+        args: Namespace
+    """
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-m",
+        "--model",
+        required=False,
+        type=str,
+        default="longformer-base-4096",
+        help="Checkpoint directory or pre-trained model names in the list: "
+        + ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
+    )
+
+    parser.add_argument(
+        "--export_padding",
+        required=False,
+        action="store_true",
+        help="Export padding logic to the ONNX graph. If not enabled, the user needs to pad inputs so that the sequence length is a multiple of the window size.",
+    )
+    parser.set_defaults(export_padding=False)
+
+    parser.add_argument(
+        "--no_merge_qkv",
+        required=False,
+        action="store_true",
+        help="Stack the weights of q, k and v on dimension 0 instead of dimension 1.",
+    )
+    parser.set_defaults(no_merge_qkv=False)
+
+    parser.add_argument(
+        "-o",
+        "--optimize_onnx",
+        required=False,
+        action="store_true",
+        help="Use optimizer.py to optimize the onnx model.",
+    )
+    parser.set_defaults(optimize_onnx=False)
+
+    parser.add_argument(
+        "-p",
+        "--precision",
+        required=False,
+        type=str,
+        default="fp32",
+        choices=["fp32", "fp16"],
+        help="Precision of model to run: fp32 for full precision, fp16 for mixed precision",
+    )
+
+    args = parser.parse_args()
+    return args
+
+
+# Create dummy inputs for ONNX export.
+def get_dummy_inputs(config, export_padding, device):
+    # When the sequence length is a multiple of the window size, there is no padding logic in the ONNX graph.
+    sequence_length = config.attention_window[0] + 1 if export_padding else config.attention_window[0]
+
+    # Create dummy inputs
+    input_ids = torch.arange(sequence_length).unsqueeze(0).to(device)
+
+    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
+    attention_mask[:, sequence_length - 1] = 0  # last token is masked
+
+    global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
+    global_attention_mask[:, 0] = 1  # first token is global token
+
+    return input_ids, attention_mask, global_attention_mask
+
+
+# A new function to replace LongformerSelfAttention.forward
+# For transformers 4.0.0
+def my_longformer_self_attention_forward_4(
+    self,
+    hidden_states,
+    attention_mask=None,
+    is_index_masked=None,
+    is_index_global_attn=None,
+    is_global_attn=None,
+):
+    global_mask = is_index_global_attn.int()
+    # The following check is based on the dummy inputs (only the first token is global).
+    assert (
+        len(global_mask.shape) == 2
+        and global_mask.shape[0] == 1
+        and global_mask.count_nonzero().item() == 1
+        and global_mask.tolist()[0][0] == 1
+    )
+
+    input_mask = is_index_masked.float()
+    # TODO: The filtering value may be -10000.0 or -inf. Check the huggingface implementation.
+    input_mask = input_mask.masked_fill(is_index_masked, -10000.0)
+    # Yet another way to generate input_mask: torch.masked_fill(attention_mask, is_index_global_attn, 0.0)
+
+    # TODO: add postprocessing of the ONNX model to calculate based on graph input: input_mask = (attention_mask - 1) * 10000.0
+    # TODO: add postprocessing of the ONNX model to use the graph input directly: global_mask = global_attention_mask
+
+    # The following check is based on the dummy inputs (only the last token is masked).
+    assert (
+        len(input_mask.shape) == 2
+        and input_mask.shape[0] == 1
+        and input_mask.count_nonzero().item() == 1
+        and input_mask.tolist()[0][-1] == -10000.0
+    )
+
+    weight = torch.stack(
+        (
+            self.query.weight.transpose(0, 1),
+            self.key.weight.transpose(0, 1),
+            self.value.weight.transpose(0, 1),
+        ),
+        dim=weight_bias_format,
+    )
+
+    if weight_bias_format == 1:
+        # shape is (hidden_size, 3*hidden_size) for format 1, otherwise (3, hidden_size, hidden_size) by default
+        weight = weight.reshape(self.embed_dim, 3 * self.embed_dim)
+
+    global_weight = torch.stack(
+        (
+            self.query_global.weight.transpose(0, 1),
+            self.key_global.weight.transpose(0, 1),
+            self.value_global.weight.transpose(0, 1),
+        ),
+        dim=weight_bias_format,
+    )
+
+    if weight_bias_format == 1:
+        global_weight = global_weight.reshape(self.embed_dim, 3 * self.embed_dim)
+
+    if weight_bias_format == 1:
+        bias = torch.stack((self.query.bias, self.key.bias, self.value.bias), dim=0)
+        bias = bias.reshape(3 * self.embed_dim)
+        global_bias = torch.stack((self.query_global.bias, self.key_global.bias, self.value_global.bias), dim=0)
+        global_bias = global_bias.reshape(3 * self.embed_dim)
+    else:
+        bias = torch.stack(
+            (self.query.bias, self.key.bias, self.value.bias, self.key_global.bias, self.value_global.bias), dim=0
+        )
+        bias = bias.reshape(5 * self.embed_dim)
+        global_bias = self.query_global.bias
+        global_bias = global_bias.reshape(1 * self.embed_dim)
+
+    attn_output = torch.ops.onnxruntime.LongformerAttention(
+        hidden_states,
+        weight,
+        bias,
+        input_mask,
+        global_weight,
+        global_bias,
+        global_mask,
+        self.num_heads,
+        self.one_sided_attn_window_size,
+    )
+
+    assert attn_output.size() == hidden_states.size(), "Unexpected size"
+
+    outputs = (attn_output,)
+    return outputs
+
+
+# For transformers 4.3.0
+def my_longformer_self_attention_forward_4_3(
+    self,
+    hidden_states,
+    attention_mask=None,
+    is_index_masked=None,
+    is_index_global_attn=None,
+    is_global_attn=None,
+    output_attentions=False,
+):
+    assert output_attentions is False
+    return my_longformer_self_attention_forward_4(
+        self,
+        hidden_states,
+        attention_mask,
+        is_index_masked,
+        is_index_global_attn,
+        is_global_attn,
+    )
+
+
+# For transformers 4.3.2 or later versions
+def my_longformer_self_attention_forward_4_3_2(
+    self,
+    hidden_states,
+    attention_mask=None,
+    layer_head_mask=None,
+    is_index_masked=None,
+    is_index_global_attn=None,
+    is_global_attn=None,
+    output_attentions=False,
+):
+    assert output_attentions is False
+    assert layer_head_mask is None
+    return my_longformer_self_attention_forward_4(
+        self,
+        hidden_states,
+        attention_mask,
+        is_index_masked,
+        is_index_global_attn,
+        is_global_attn,
+    )
+
+
+def export_longformer(model: LongformerModel, onnx_model_path: str, export_padding: bool):
+    """Export longformer model to ONNX
+
+    Args:
+        model (LongformerModel): longformer model
+        onnx_model_path (str): output onnx path
+        export_padding (bool): whether to export padding logic to ONNX so that inputs can be any length.
+
+    Raises:
+        RuntimeError: This tool requires transformers 4.0.0 or later.
+        RuntimeError: LongformerSelfAttention.forward arguments are different.
+    """
+    input_ids, attention_mask, global_attention_mask = get_dummy_inputs(
+        model.config, export_padding, device=torch.device("cpu")
+    )
+
+    _ = model(
+        input_ids,
+        attention_mask=attention_mask,
+        global_attention_mask=global_attention_mask,
+    )
+
+    if version.parse(transformers.__version__) < version.parse("4.0.0"):
+        raise RuntimeError("This tool requires transformers 4.0.0 or later.")
+
+    # Here we replace LongformerSelfAttention.forward with our implementation for exporting the ONNX model.
+    key = " ".join(inspect.getfullargspec(LongformerSelfAttention.forward).args)
+    args_to_func = {
+        "self hidden_states attention_mask layer_head_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3_2,
+        "self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3,
+        "self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn": my_longformer_self_attention_forward_4,
+    }
+
+    if key not in args_to_func:
+        print(
+            "Current arguments",
+            inspect.getfullargspec(LongformerSelfAttention.forward).args,
+        )
+        raise RuntimeError(
+            "LongformerSelfAttention.forward arguments are different. Please install a supported version (like transformers 4.3.0)."
+        )
+
+    # Store for restoring later
+    original_forward = LongformerSelfAttention.forward
+
+    LongformerSelfAttention.forward = args_to_func[key]
+
+    example_inputs = (input_ids, attention_mask, global_attention_mask)
+
+    Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+
+    torch_onnx_export(
+        model,
+        example_inputs,
+        onnx_model_path,
+        opset_version=12,
+        input_names=["input_ids", "attention_mask", "global_attention_mask"],
+        output_names=["last_state", "pooler"],
+        dynamic_axes={
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+            "global_attention_mask": {0: "batch_size", 1: "sequence_length"},
+            "last_state": {0: "batch_size", 1: "sequence_length"},
+            "pooler": {0: "batch_size", 1: "sequence_length"},
+        },
+        custom_opsets={"com.microsoft": 1},
+    )
+    print(f"ONNX model exported to {onnx_model_path}")
+
+    # Restore the original implementation.
+    LongformerSelfAttention.forward = original_forward
+
+
+def optimize_longformer(onnx_model_path: str, fp32_model_path: str, fp16_model_path=None):
+    """Optimize longformer onnx model
+
+    Args:
+        onnx_model_path (str): path of the original ONNX model.
+        fp32_model_path (str): path of the optimized fp32 model.
+        fp16_model_path (str, optional): path of the optimized fp16 model. Defaults to None.
+    """
+    model = load_model(onnx_model_path, format=None, load_external_data=True)
+    optimizer = BertOnnxModel(model)
+    optimizer.optimize()
+
+    use_external_data_format = False
+    if fp32_model_path:
+        optimizer.save_model_to_file(fp32_model_path, use_external_data_format)
+        print(f"optimized fp32 model saved to {fp32_model_path}")
+
+    if fp16_model_path:
+        optimizer.convert_float_to_float16(keep_io_types=True)
+        optimizer.save_model_to_file(fp16_model_path, use_external_data_format)
+        print(f"optimized fp16 model saved to {fp16_model_path}")
+
+
+def main(args):
+    model_name = args.model
+    onnx_model_path = model_name + ".onnx"
+
+    global weight_bias_format  # noqa: PLW0603
+    weight_bias_format = 0 if args.no_merge_qkv else 1
+
+    model = LongformerModel.from_pretrained(PRETRAINED_LONGFORMER_MODELS[model_name])
+
+    export_longformer(model, onnx_model_path, args.export_padding)
+
+    if args.optimize_onnx or args.precision != "fp32":
+        fp32_model_path = model_name + f"_f{weight_bias_format}" + "_fp32.onnx"
+        fp16_model_path = model_name + f"_f{weight_bias_format}" + "_fp16.onnx" if args.precision == "fp16" else None
+        optimize_longformer(onnx_model_path, fp32_model_path, fp16_model_path)
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    main(args)