onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,821 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
#
# This script benchmarks latency or peak memory usage of Longformer model inference.
# Please run convert_to_onnx.py to get an onnx model before running the benchmark.
#
# It is tested with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.11.0, transformers 4.18.0, CUDA 11.3, like:
#   conda create -n gpu_env python=3.8
#   conda activate gpu_env
#   pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
#   pip3 install onnx transformers onnxruntime-gpu numpy sympy psutil py3nvml
#   python benchmark_longformer.py
#
# When there is no parameter, pre-defined tests will run on the longformer-base-4096 model.

# Benchmark the latency:
#   python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 512 1024 2048 4096 \
#          --global_lengths 8 --onnx ./longformer-base-4096_fp16.onnx -t 100
#
# Benchmark GPU peak memory:
#   export ORT_LONGFORMER_COMPACT_MEMORY=0
#   python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 \
#          --global_lengths 8 --onnx ./longformer-base-4096_fp32.onnx --memory -t 10 --engine onnxruntime
#   export ORT_LONGFORMER_COMPACT_MEMORY=1
#   python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 \
#          --global_lengths 8 --onnx ./longformer-base-4096_fp32.onnx --memory -t 10 --engine onnxruntime
#
# By default, the compact memory kernel is enabled. To disable it, set the environment variable
# ORT_LONGFORMER_COMPACT_MEMORY=0.

import argparse
import csv
import logging
import math
import os
import re
import sys
import timeit
import traceback
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from typing import Any

import benchmark_helper
import numpy as np
import torch
from longformer_helper import PRETRAINED_LONGFORMER_MODELS, LongformerHelper, LongformerInputs
from transformers import LongformerModel

import onnxruntime

logger = logging.getLogger("")

def test_torch_latency(
    device,
    model,
    model_name,
    batch_sizes,
    sequence_lengths,
    global_lengths,
    test_times,
    num_threads,
) -> list[dict[str, Any]]:
    if num_threads > 0:
        torch.set_num_threads(num_threads)

    results = []
    for batch_size in batch_sizes:
        for sequence_length in sequence_lengths:
            for global_length in global_lengths:
                logger.info(f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}")
                inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
                    batch_size, sequence_length, global_length, device
                )
                input_list = inputs.to_list()

                _ = model(*input_list)
                runtimes = timeit.repeat(lambda: model(*input_list), repeat=test_times, number=1)  # noqa: B023
                result = {
                    "engine": "torch",  # TODO: test torchscript
                    "version": torch.__version__,
                    "device": "cuda",
                    "optimizer": "",
                    "precision": "fp32",
                    "io_binding": "",
                    "model_name": model_name,
                    "description": model_name + " [torch]",
                    "inputs": 3,
                    "threads": num_threads,
                    "batch_size": batch_size,
                    "sequence_length": sequence_length,
                    "global_length": global_length,
                    "datetime": str(datetime.now()),
                    "memory": "NA",
                    "diff_max": 0,
                    "diff_90_percentile": 0,
                    "diff_95_percentile": 0,
                    "diff_99_percentile": 0,
                    "use_compact_memory": "NA",
                }
                result.update(benchmark_helper.get_latency_result(runtimes, batch_size))
                logger.info("%s", result)
                results.append(result)
    return results

def test_parity(device, model, ort_session, batch_size, sequence_length, global_length, verbose=True):
    parameters = f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}"
    logger.info(f"Comparing Torch and ORT outputs for {parameters}...")
    dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
        batch_size, sequence_length, global_length, device
    )
    ort_inputs = dummy_inputs.get_ort_inputs()
    ort_outputs = ort_session.run(None, ort_inputs)
    input_list = dummy_inputs.to_list()
    torch_outputs = model(*input_list)
    max_diff = np.amax(torch_outputs[0].cpu().numpy() - ort_outputs[0])
    logger.info(f"last_state max diff = {max_diff}")
    if verbose and (math.isnan(max_diff) or max_diff > 0.001):
        print("torch last_state:", torch_outputs[0])
        print("ort last_state:", ort_outputs[0])
    return float(max_diff)

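# Note: test_parity reports the signed maximum difference of the last_state
# output (not the absolute difference), and a NaN or a difference above 1e-3
# triggers a verbose dump of both the torch and ORT outputs.
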
def test_ort_latency(
    device,
    model,
    model_name,
    description,
    ort_session,
    batch_sizes,
    sequence_lengths,
    global_lengths,
    test_times,
    num_threads,
    optimizer=False,
    precision="fp32",
    disable_io_binding=False,
    verbose=True,
    use_compact_memory=False,
    use_half4=False,
    disable_parity=False,
) -> list[dict[str, Any]]:
    results = []
    for batch_size in batch_sizes:
        for sequence_length in sequence_lengths:
            for global_length in global_lengths:
                assert global_length <= model.config.attention_window[0], (
                    "Limitation of current implementation: number of global tokens <= attention_window"
                )

                logger.info(
                    f"Testing batch_size={batch_size} sequence_length={sequence_length} global_length={global_length} "
                    f"optimizer={optimizer}, precision={precision} io_binding={not disable_io_binding}..."
                )
                dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
                    batch_size, sequence_length, global_length, device
                )

                # Run OnnxRuntime
                ort_inputs = dummy_inputs.get_ort_inputs()

                if verbose:
                    print(ort_inputs)

                # run one query for warm up
                ort_outputs = ort_session.run(None, ort_inputs)

                result_template = {
                    "model_name": model_name,
                    "description": description,
                    "inputs": 3,
                    "engine": "OnnxRuntime",
                    "version": str(onnxruntime.__version__),
                    "device": "cuda",
                    "precision": str(precision),
                    "optimizer": int(optimizer),
                    "threads": int(num_threads),
                    "batch_size": int(batch_size),
                    "sequence_length": int(sequence_length),
                    "global_length": int(global_length),
                    "test_times": int(test_times),
                    "datetime": str(datetime.now()),
                    "memory": "",
                    "diff_max": None,
                    "diff_90_percentile": None,
                    "diff_95_percentile": None,
                    "diff_99_percentile": None,
                    "use_compact_memory": use_compact_memory,
                    "use_half4": use_half4,
                }

                if not disable_io_binding:
                    max_last_state_size = max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size
                    max_pooler_size = max(batch_sizes) * max(sequence_lengths)
                    result = benchmark_helper.inference_ort_with_io_binding(
                        ort_session,
                        ort_inputs,
                        result_template=result_template,
                        repeat_times=test_times,
                        ort_output_names=["last_state", "pooler"],
                        ort_outputs=ort_outputs,
                        output_buffers=[],
                        output_buffer_max_sizes=[max_last_state_size, max_pooler_size],
                        batch_size=batch_size,
                        device=device,
                        data_type=np.longlong,  # input data type
                    )
                else:
                    result = benchmark_helper.inference_ort(
                        ort_session,
                        ort_inputs,
                        result_template=result_template,
                        repeat_times=test_times,
                        batch_size=batch_size,
                    )

                # measure result difference between PyTorch and OnnxRuntime
                if not disable_parity:
                    diff_results = [
                        test_parity(
                            device,
                            model,
                            ort_session,
                            batch_size,
                            sequence_length,
                            global_length,
                            verbose,
                        )
                        for _ in range(test_times)
                    ]

                    result["diff_max"] = max(diff_results)
                    result["diff_90_percentile"] = np.percentile(diff_results, 90)
                    result["diff_95_percentile"] = np.percentile(diff_results, 95)
                    result["diff_99_percentile"] = np.percentile(diff_results, 99)

                results.append(result)
    return results

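# Note on the IO binding path above: output buffer sizes are computed from the
# largest batch size and sequence length under test, so the allocated
# "last_state" and "pooler" buffers are big enough for every tested configuration.
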
def test_ort_memory(
    device,
    onnx_model_path,
    batch_size,
    sequence_length,
    global_length,
    test_times,
    num_threads,
) -> dict[str, Any]:
    logger.info(
        f"Testing memory for model={onnx_model_path}, batch_size={batch_size}, sequence_length={sequence_length}, "
        f"global_length={global_length}, test_times={test_times}, num_threads={num_threads}"
    )

    def inference():
        # Update Arena strategy so that we can measure the minimum memory required
        cuda_provider_options = {"arena_extend_strategy": "kSameAsRequested"}
        provider_options = {"CUDAExecutionProvider": cuda_provider_options}
        session = benchmark_helper.create_onnxruntime_session(
            onnx_model_path,
            use_gpu=True,
            enable_all_optimization=True,
            num_threads=num_threads,
            provider_options=provider_options,
        )

        dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
            batch_size, sequence_length, global_length, device
        )
        ort_inputs = dummy_inputs.get_ort_inputs()
        for _ in range(test_times):
            _ = session.run(None, ort_inputs)

    memory_used = benchmark_helper.measure_memory(is_gpu=True, func=inference)

    return {
        "onnx_model": onnx_model_path,
        "batch_size": batch_size,
        "sequence_length": sequence_length,
        "global_length": global_length,
        "test_times": test_times,
        "num_threads": num_threads,
        "memory": memory_used,
    }

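# Note: "kSameAsRequested" makes the CUDA memory arena grow only by the size
# each allocation actually requests; with the default growth strategy the arena
# can over-allocate, which would inflate the peak memory measured above.
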
def load_torch_model(model_name, device):
    torch_model_name_or_dir = PRETRAINED_LONGFORMER_MODELS.get(model_name, model_name)
    model = LongformerModel.from_pretrained(torch_model_name_or_dir)
    model.to(device)
    return model


def find_onnx_model(model_name, onnx_dir="."):
    # Search onnx model in the following order: optimized fp16 model, optimized fp32 model, raw model
    onnx_model_path = os.path.join(onnx_dir, model_name + ".onnx")
    optimized_fp32_model = os.path.join(onnx_dir, model_name + "_fp32.onnx")
    optimized_fp16_model = os.path.join(onnx_dir, model_name + "_fp16.onnx")
    if os.path.isfile(optimized_fp16_model):
        onnx_model_path = optimized_fp16_model
    elif os.path.isfile(optimized_fp32_model):
        onnx_model_path = optimized_fp32_model
    return onnx_model_path

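# Example: for model_name "longformer-base-4096", find_onnx_model checks
# "longformer-base-4096_fp16.onnx" first, then "longformer-base-4096_fp32.onnx",
# and falls back to "longformer-base-4096.onnx" in the given onnx_dir.
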
def test_memory(args, device) -> dict[str, Any]:
    if len(args.batch_sizes) > 1:
        raise RuntimeError("For memory test, only one batch_size (-b) is allowed.")
    if len(args.sequence_lengths) > 1:
        raise RuntimeError("For memory test, only one sequence_length (-s) is allowed.")
    if len(args.global_lengths) > 1:
        raise RuntimeError("For memory test, only one global_length (-g) is allowed.")

    model_name = args.model
    onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx

    torch.cuda.empty_cache()
    return test_ort_memory(
        device,
        onnx_model_path,
        args.batch_sizes[0],
        args.sequence_lengths[0],
        args.global_lengths[0],
        args.test_times,
        args.num_threads,
    )


def test_ort(args, device) -> list[dict[str, Any]]:
    model_name = args.model

    onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx

    optimized = onnx_model_path.endswith("_fp16.onnx") or onnx_model_path.endswith("_fp32.onnx")  # noqa: PIE810
    precision = "fp32" if not onnx_model_path.endswith("_fp16.onnx") else "fp16"

    model = load_torch_model(model_name, device)

    num_threads = args.num_threads

    cuda_provider_options = {"arena_extend_strategy": "kSameAsRequested"}
    provider_options = {"CUDAExecutionProvider": cuda_provider_options}
    session = benchmark_helper.create_onnxruntime_session(
        onnx_model_path,
        use_gpu=True,
        enable_all_optimization=True,
        num_threads=num_threads,
        provider_options=provider_options,
    )
    if session is None:
        raise RuntimeError(f"Failed to create ORT session from ONNX file {onnx_model_path}")

    use_compact_memory = os.environ.get("ORT_LONGFORMER_COMPACT_MEMORY", "1") == "1"
    description = onnx_model_path
    if not use_compact_memory:
        description += "[non_compact_memory]"

    if args.use_half4:
        description += "[half4]" if precision == "fp16" else "[float4]"
    else:
        description += "[half2]" if precision == "fp16" else "[float4]"

    return test_ort_latency(
        device,
        model,
        model_name,
        description,
        session,
        args.batch_sizes,
        args.sequence_lengths,
        args.global_lengths,
        args.test_times,
        num_threads,
        optimized,
        precision,
        args.disable_io_binding,
        args.verbose,
        use_compact_memory,
        args.use_half4,
        args.disable_parity,
    )


def test_torch(args, device) -> list[dict[str, Any]]:
    model = load_torch_model(args.model, device)
    return test_torch_latency(
        device,
        model,
        args.model,
        args.batch_sizes,
        args.sequence_lengths,
        args.global_lengths,
        args.test_times,
        args.num_threads,
    )


def test_latency(args, device) -> list[dict[str, Any]]:
    if args.engine == "onnxruntime":
        return test_ort(args, device)

    return test_torch(args, device)

def parse_arguments(argv=None):
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="longformer-base-4096",
        help="Checkpoint directory or pre-trained model names in the list: "
        + ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
    )

    parser.add_argument(
        "-e",
        "--engine",
        required=False,
        type=str,
        default="onnxruntime",
        choices=["onnxruntime", "torch"],
        help="Engine to benchmark.",
    )

    parser.add_argument(
        "-t",
        "--test_times",
        required=False,
        default=1000,
        type=int,
        help="Number of repeat times to get average inference latency.",
    )

    parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1])

    # If --export_padding is not used when exporting the onnx model, there is no padding in the ONNX model,
    # and you will need to pad inputs yourself before running the onnx model.
    # Here, we only test sequence lengths that are multiples of the attention window size.
    parser.add_argument(
        "-s",
        "--sequence_lengths",
        nargs="+",
        type=int,
        default=[512, 1024, 2048, 4096],
        help="Sequence lengths. It could have multiple values in latency test. "
        "If --export_padding is not used, sequence length shall be a multiple of the window size.",
    )

    parser.add_argument("--onnx", required=False, type=str, default=None, help="Onnx model path")

    parser.add_argument(
        "-g",
        "--global_lengths",
        nargs="+",
        type=int,
        default=[0],
        help="Number of global tokens. It could have multiple values in latency test.",
    )

    parser.add_argument(
        "-n",
        "--num_threads",
        required=False,
        type=int,
        default=0,
        help="Threads to use.",
    )

    parser.add_argument(
        "--disable_io_binding",
        required=False,
        action="store_true",
        help="Do not use IO Binding.",
    )

    parser.add_argument(
        "--memory",
        required=False,
        action="store_true",
        help="Test memory usage instead of latency.",
    )

    parser.add_argument("--verbose", required=False, action="store_true", help="Print more information.")
    parser.set_defaults(verbose=False)

    parser.add_argument("--use_half4", required=False, action="store_true", help="Use half4 kernel.")
    parser.set_defaults(use_half4=False)

    parser.add_argument("--disable_parity", required=False, action="store_true", help="Do not run parity test.")
    parser.set_defaults(disable_parity=False)

    args = parser.parse_args(argv)

    return args

def output_details(results, csv_filename):
    latency_results = [result for result in results if "average_latency_ms" in result]
    if len(latency_results) == 0:
        print("No latency results for output.")
        return

    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        column_names = [
            "engine",
            "version",
            "device",
            "precision",
            "optimizer",
            "io_binding",
            "model_name",
            "inputs",
            "threads",
            "datetime",
            "test_times",
            "description",
            "batch_size",
            "sequence_length",
            "global_length",
            "use_compact_memory",
            "use_half4",
            "diff_max",
            "diff_90_percentile",
            "diff_95_percentile",
            "diff_99_percentile",
            "memory",
            "QPS",
            "average_latency_ms",
            "latency_variance",
            "latency_90_percentile",
            "latency_95_percentile",
            "latency_99_percentile",
        ]

        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
        csv_writer.writeheader()
        for result in latency_results:
            print(result)
            csv_writer.writerow(result)

        csv_file.flush()

    print(f"Detail results are saved to csv file: {csv_filename}")

def run(args) -> list[dict[str, Any]]:
    torch.set_grad_enabled(False)

    # set random seed manually to get deterministic results
    benchmark_helper.set_random_seed(123)

    # Currently, the longformer attention operator can only run on GPU (no CPU implementation yet).
    device = torch.device("cuda:0")

    if args.memory:
        return [test_memory(args, device)]  # Convert to list so that the return type is the same as test_latency

    return test_latency(args, device)


def launch_test(arguments) -> list[dict[str, Any]]:
    if not torch.cuda.is_available():
        raise RuntimeError("Please install PyTorch with CUDA support, and use a machine with GPU for testing GPU performance.")

    with ProcessPoolExecutor() as executor:
        results = list(executor.map(run, [arguments]))
        assert len(results) == 1
        return results[0]

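# Note: launch_test runs each benchmark in a child process (a ProcessPoolExecutor
# with a single task) so that GPU memory and the CUDA context are released
# between runs; main() sets the "spawn" start method to make this work with CUDA.
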
def run_tests(
    use_compact_memory=True,
    run_torch=False,
    run_memory=True,
    use_io_binding=True,
    use_fp16=True,
    use_merged_qkv_weights=True,
    use_half4=True,
    batch_size=1,
):
    compact_memory = "1" if use_compact_memory else "0"
    os.environ["ORT_LONGFORMER_COMPACT_MEMORY"] = compact_memory
    logger.info(f"ORT_LONGFORMER_COMPACT_MEMORY={compact_memory}")

    os.environ["ORT_LONGFORMER_USE_HALF4"] = "1" if use_half4 else "0"
    logger.info("ORT_LONGFORMER_USE_HALF4={}".format("1" if use_half4 else "0"))  # noqa: G001

    results = []
    test_times = 1000
    sequence_lengths = [4096, 2048, 1024, 512]
    batch_sizes = [batch_size]
    for model_name in ["longformer-base-4096"]:
        for batch_size in batch_sizes:
            for sequence_length in sequence_lengths:
                for global_length in [16]:
                    if run_torch:
                        engine_name = "torch"
                        args = parse_arguments(
                            f"-e {engine_name} -b {batch_size} -s {sequence_length} -g {global_length} "
                            f"-t {test_times} -m {model_name}".split(" ")
                        )
                        results += run(args)

                    engine_name = "onnxruntime"
                    file_format = 1 if use_merged_qkv_weights else 0
                    onnx_path = (
                        f"{model_name}_f{file_format}_fp16.onnx"
                        if use_fp16
                        else f"{model_name}_f{file_format}_fp32.onnx"
                    )
                    if not os.path.exists(onnx_path):
                        raise RuntimeError(f"onnx file does not exist: {onnx_path}")

                    arguments = (
                        f"-e {engine_name} --onnx {onnx_path} "
                        f"-b {batch_size} -s {sequence_length} -g {global_length} -m {model_name}"
                    )

                    if not use_io_binding:
                        arguments += " --disable_io_binding"

                    if use_half4:
                        arguments += " --use_half4"

                    # Disable parity test to avoid out of memory for large batch size
                    if batch_size >= 4:
                        arguments += " --disable_parity"

                    memory_results = None
                    try:
                        if run_memory:
                            args = parse_arguments(f"{arguments} -t 10 --memory".split(" "))
                            memory_results = launch_test(args)

                        args = parse_arguments(f"{arguments} -t {test_times}".split(" "))
                        latency_results = launch_test(args)
                    except KeyboardInterrupt as exc:
                        raise RuntimeError("Keyboard Interrupted") from exc
                    except Exception:
                        traceback.print_exc()
                        continue

                    if len(latency_results) == 1:
                        latency_results[0]["memory"] = memory_results[0]["memory"] if memory_results else "N/A"
                    else:
                        raise RuntimeError("length of latency_results should be 1")

                    logger.info("%s", latency_results)

                    results += latency_results
    return results

def output_summary(results, csv_filename, data_field="average_latency_ms"):
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        header_names = [
            "model_name",
            "precision",
            "engine",
            "version",
            "global_length",
            "use_compact_memory",
            "use_half4",
            "description",
        ]

        description_list = list({result["description"] for result in results})
        description_list.sort()

        batch_sizes = list({result["batch_size"] for result in results})
        batch_sizes.sort()

        sequence_lengths = list({result["sequence_length"] for result in results})
        sequence_lengths.sort()

        data_names = []
        for sequence_length in sequence_lengths:
            for batch_size in batch_sizes:
                data_names.append(f"b{batch_size}_s{sequence_length}")

        csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
        csv_writer.writeheader()

        for description in description_list:
            row = {}

            sum_latency = {}
            sum_latency.update(dict.fromkeys(data_names, 0))

            count_latency = {}
            count_latency.update(dict.fromkeys(data_names, 0))

            for result in results:
                if result["description"] == description and result[data_field]:
                    headers = {k: v for k, v in result.items() if k in header_names}
                    if not row:
                        row.update(headers)
                    else:
                        for k in header_names:
                            if row[k] != headers[k]:
                                raise RuntimeError("Description shall be unique")

                    batch_size = result["batch_size"]
                    sequence_length = result["sequence_length"]
                    key = f"b{batch_size}_s{sequence_length}"

                    try:
                        latency = float(result[data_field])
                    except ValueError:
                        continue

                    sum_latency[key] += latency
                    count_latency[key] += 1

            if row:
                for key in data_names:
                    if key in count_latency and count_latency[key] > 0:
                        row[key] = sum_latency[key] / count_latency[key]
                    else:
                        row[key] = ""

                csv_writer.writerow(row)

        csv_file.flush()

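# Example: with results covering batch_sizes [1, 2] and sequence_lengths
# [512, 4096], output_summary writes one row per description with data columns
# b1_s512, b2_s512, b1_s4096 and b2_s4096, each holding the average of the
# chosen metric over matching results.
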
def run_experiments(use_fp16, batch_size, is_baseline=False):
    """Run experiments to compare different algorithms on one batch size."""
    test_results = run_tests(
        use_fp16=use_fp16,
        use_merged_qkv_weights=True,
        use_half4=False,
        batch_size=batch_size,
    )

    if is_baseline:
        return test_results

    if use_fp16:
        test_results += run_tests(
            use_fp16=use_fp16,
            use_merged_qkv_weights=True,
            use_half4=True,
            batch_size=batch_size,
        )

        test_results += run_tests(
            use_fp16=use_fp16,
            use_merged_qkv_weights=False,
            use_half4=True,
            batch_size=batch_size,
        )

    test_results += run_tests(
        use_fp16=use_fp16,
        use_merged_qkv_weights=False,
        use_half4=False,
        batch_size=batch_size,
    )

    return test_results


def main():
    torch.multiprocessing.set_start_method("spawn")

    args = parse_arguments()

    benchmark_helper.setup_logger(args.verbose)

    if len(sys.argv) > 1:
        test_results = launch_test(args)
        time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        csv_filename = f"benchmark_detail_{time_stamp}.csv"
        output_details(test_results, csv_filename)
        return

    gpu_list = benchmark_helper.get_gpu_info()
    logger.info("GPU info: %s", gpu_list)
    fp16_batch_sizes = [16, 8, 4, 2, 1]
    fp32_batch_sizes = [4, 2, 1]
    if gpu_list and gpu_list[0]["total"] >= 32 * 1024 * 1024 * 1024:  # 32 GB
        fp16_batch_sizes = [64, 32, 16, 8, 4, 2, 1]
        fp32_batch_sizes = [16, 8, 4, 2, 1]

    gpu_name = re.sub(r"(?u)[^-\w.]", "_", gpu_list[0]["name"]) if gpu_list else "gpu"
    is_baseline = os.environ.get("ORT_LONGFORMER_BASELINE", "0") == "1"
    experiment_name = f"longformer_base_{gpu_name}" + ("_baseline" if is_baseline else "")
    logger.info(
        f"experiment_name={experiment_name}, fp16_batch_sizes={fp16_batch_sizes}, fp32_batch_sizes={fp32_batch_sizes}"
    )

    total_runs = 1
    all_results = []
    for _ in range(total_runs):
        for batch_size in fp16_batch_sizes:
            fp16_results = run_experiments(use_fp16=True, batch_size=batch_size, is_baseline=is_baseline)
            output_details(fp16_results, "longformer_base_fp16.csv")
            all_results += fp16_results
    for metric_name in ["average_latency_ms", "QPS", "memory", "diff_90_percentile"]:
        output_summary(all_results, f"{experiment_name}_{metric_name}.csv", metric_name)

    all_results = []
    for _ in range(total_runs):
        for batch_size in fp32_batch_sizes:
            fp32_results = run_experiments(use_fp16=False, batch_size=batch_size, is_baseline=is_baseline)
            output_details(fp32_results, "longformer_base_fp32.csv")
            all_results += fp32_results
    for metric_name in ["average_latency_ms", "QPS", "memory", "diff_90_percentile"]:
        output_summary(all_results, f"{experiment_name}_{metric_name}.csv", metric_name)


if __name__ == "__main__":
    main()
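
For reference, a minimal sketch of driving this benchmark from Python rather than the shell (an illustration, not part of the package). It assumes the script's directory is the current working directory so that benchmark_helper and longformer_helper resolve as top-level imports, that longformer-base-4096_fp16.onnx was already exported by convert_to_onnx.py, and that a CUDA-capable GPU is available:

    import benchmark_longformer

    # Build the same argument namespace the command line would produce.
    args = benchmark_longformer.parse_arguments(
        "--onnx ./longformer-base-4096_fp16.onnx -b 1 -s 512 -g 8 -t 100".split()
    )

    # launch_test runs the benchmark in a child process and returns a list of result dicts.
    results = benchmark_longformer.launch_test(args)
    for result in results:
        print(result["description"], result.get("average_latency_ms"))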