onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,642 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
|
|
6
|
+
# This is a tool to generate test data for a BERT model.
|
|
7
|
+
# The test data can be used by onnxruntime_perf_test tool to evaluate the inference latency.
|
|
8
|
+
|
|
9
|
+
import argparse
|
|
10
|
+
import os
|
|
11
|
+
import random
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Dict, Optional, Tuple
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from onnx import ModelProto, TensorProto, numpy_helper
|
|
17
|
+
from onnx_model import OnnxModel
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def fake_input_ids_data(
    input_ids: TensorProto, batch_size: int, sequence_length: int, dictionary_size: int
) -> np.ndarray:
    """Generate a random token-id tensor matching the input_ids graph input.

    Args:
        input_ids (TensorProto): graph input of the input_ids input tensor
        batch_size (int): batch size
        sequence_length (int): sequence length
        dictionary_size (int): vocabulary size of dictionary

    Returns:
        np.ndarray: the input tensor created
    """
    elem_type = input_ids.type.tensor_type.elem_type
    # Only float32/int32/int64 token-id tensors are supported.
    assert elem_type in (TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64)

    # Draw token ids uniformly from the vocabulary, then cast to the graph's element type.
    tokens = np.random.randint(dictionary_size, size=(batch_size, sequence_length), dtype=np.int32)

    if elem_type == TensorProto.FLOAT:
        return np.float32(tokens)
    if elem_type == TensorProto.INT64:
        return np.int64(tokens)
    return tokens
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def fake_segment_ids_data(segment_ids: TensorProto, batch_size: int, sequence_length: int) -> np.ndarray:
    """Generate an all-zero tensor matching the segment_ids (token_type_ids) graph input.

    Args:
        segment_ids (TensorProto): graph input of the token_type_ids input tensor
        batch_size (int): batch size
        sequence_length (int): sequence length

    Returns:
        np.ndarray: the input tensor created
    """
    elem_type = segment_ids.type.tensor_type.elem_type
    # Only float32/int32/int64 segment-id tensors are supported.
    assert elem_type in (TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64)

    # A single segment is assumed, so every token gets segment id 0.
    zeros = np.zeros((batch_size, sequence_length), dtype=np.int32)

    if elem_type == TensorProto.FLOAT:
        return np.float32(zeros)
    if elem_type == TensorProto.INT64:
        return np.int64(zeros)
    return zeros
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def get_random_length(max_sequence_length: int, average_sequence_length: int) -> int:
    """Draw a random sequence length whose expected value is average_sequence_length.

    The value is drawn from a uniform distribution whose bounds are chosen so
    that the distribution's mean is exactly the requested average while never
    exceeding max_sequence_length or dropping below 1.
    """
    assert 1 <= average_sequence_length <= max_sequence_length

    double_average = 2 * average_sequence_length
    if double_average > max_sequence_length:
        # Upper bound is clipped at the maximum, so raise the lower bound to keep the mean.
        return random.randint(double_average - max_sequence_length, max_sequence_length)
    # Symmetric range [1, 2 * average - 1] centered on the average.
    return random.randint(1, double_average - 1)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def fake_input_mask_data(
    input_mask: TensorProto,
    batch_size: int,
    sequence_length: int,
    average_sequence_length: int,
    random_sequence_length: bool,
    mask_type: int = 2,
) -> np.ndarray:
    """Generate a tensor matching the attention-mask graph input.

    Args:
        input_mask (TensorProto): graph input of the attention mask input tensor
        batch_size (int): batch size
        sequence_length (int): sequence length
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type (int): mask type - 1: mask index (sequence length excluding paddings). Shape is (batch_size).
                         2: 2D attention mask. Shape is (batch_size, sequence_length).
                         3: key len, cumulated lengths of query and key. Shape is (3 * batch_size + 2).

    Returns:
        np.ndarray: the input tensor created
    """
    elem_type = input_mask.type.tensor_type.elem_type
    # Only float32/int32/int64 mask tensors are supported.
    assert elem_type in (TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64)

    if mask_type == 1:
        # Mask index: one sequence length (excluding paddings) per batch entry.
        data = np.ones(batch_size, dtype=np.int32)
        for row in range(batch_size):
            if random_sequence_length:
                data[row] = get_random_length(sequence_length, average_sequence_length)
            else:
                data[row] = average_sequence_length
    elif mask_type == 2:
        # 2D mask: ones for valid tokens at the start of each row, zeros for padding.
        data = np.zeros((batch_size, sequence_length), dtype=np.int32)
        if random_sequence_length:
            for row in range(batch_size):
                valid_length = get_random_length(sequence_length, average_sequence_length)
                data[row, :valid_length] = 1
        else:
            data[:, :average_sequence_length] = 1
    else:
        assert mask_type == 3
        # Layout: [len_0..len_{b-1}, cum_0..cum_b, cum_0..cum_b] where cum_i is the
        # cumulative sum of the first i lengths (query and key offsets are identical).
        data = np.zeros(batch_size * 3 + 2, dtype=np.int32)
        if random_sequence_length:
            for row in range(batch_size):
                data[row] = get_random_length(sequence_length, average_sequence_length)
            running_total = 0
            data[batch_size] = 0
            data[2 * batch_size + 1] = 0
            for i in range(1, batch_size + 1):
                running_total += data[i - 1]
                data[batch_size + i] = running_total
                data[2 * batch_size + 1 + i] = running_total
        else:
            data[:batch_size] = average_sequence_length
            for i in range(batch_size + 1):
                cumulative = i * average_sequence_length
                data[batch_size + i] = cumulative
                data[2 * batch_size + 1 + i] = cumulative

    if elem_type == TensorProto.FLOAT:
        return np.float32(data)
    if elem_type == TensorProto.INT64:
        return np.int64(data)
    return data
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def output_test_data(directory: str, inputs: Dict[str, np.ndarray]):
    """Write each input tensor to input_{i}.pb files under a directory.

    Args:
        directory (str): path of a directory
        inputs (Dict[str, np.ndarray]): map from input name to value
    """
    if os.path.exists(directory):
        print(f"Warning: directory {directory} existed. Files will be overwritten.")
    else:
        try:
            os.mkdir(directory)
        except OSError:
            print(f"Creation of the directory {directory} failed")
        else:
            print(f"Successfully created the directory {directory} ")

    # Serialize each tensor as an ONNX TensorProto protobuf file.
    for index, (name, data) in enumerate(inputs.items()):
        proto = numpy_helper.from_array(data, name)
        with open(os.path.join(directory, f"input_{index}.pb"), "wb") as file:
            file.write(proto.SerializeToString())
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def fake_test_data(
    batch_size: int,
    sequence_length: int,
    test_cases: int,
    dictionary_size: int,
    verbose: bool,
    random_seed: int,
    input_ids: TensorProto,
    segment_ids: TensorProto,
    input_mask: TensorProto,
    average_sequence_length: int,
    random_sequence_length: bool,
    mask_type: int,
):
    """Create given number of input data for testing.

    Args:
        batch_size (int): batch size
        sequence_length (int): sequence length
        test_cases (int): number of test cases
        dictionary_size (int): vocabulary size of dictionary for input_ids
        verbose (bool): print more information or not
        random_seed (int): random seed
        input_ids (TensorProto): graph input of input IDs
        segment_ids (TensorProto): graph input of token type IDs
        input_mask (TensorProto): graph input of attention mask
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key

    Returns:
        List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
                                       with input name as key and a tensor as value
    """
    assert input_ids is not None

    # Seed both numpy and the random module so generated data is reproducible.
    np.random.seed(random_seed)
    random.seed(random_seed)

    all_inputs = []
    for case_index in range(test_cases):
        inputs = {input_ids.name: fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)}

        # segment_ids and input_mask are optional graph inputs.
        if segment_ids:
            inputs[segment_ids.name] = fake_segment_ids_data(segment_ids, batch_size, sequence_length)

        if input_mask:
            inputs[input_mask.name] = fake_input_mask_data(
                input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length, mask_type
            )

        if verbose and case_index == 0:
            print("Example inputs", inputs)
        all_inputs.append(inputs)
    return all_inputs
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def generate_test_data(
    batch_size: int,
    sequence_length: int,
    test_cases: int,
    seed: int,
    verbose: bool,
    input_ids: TensorProto,
    segment_ids: TensorProto,
    input_mask: TensorProto,
    average_sequence_length: int,
    random_sequence_length: bool,
    mask_type: int,
):
    """Create given number of input data for testing.

    Args:
        batch_size (int): batch size
        sequence_length (int): sequence length
        test_cases (int): number of test cases
        seed (int): random seed
        verbose (bool): print more information or not
        input_ids (TensorProto): graph input of input IDs
        segment_ids (TensorProto): graph input of token type IDs
        input_mask (TensorProto): graph input of attention mask
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key

    Returns:
        List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
                                       with input name as key and a tensor as value
    """
    # Fixed vocabulary size used for the random token ids.
    dictionary_size = 10000

    all_inputs = fake_test_data(
        batch_size=batch_size,
        sequence_length=sequence_length,
        test_cases=test_cases,
        dictionary_size=dictionary_size,
        verbose=verbose,
        random_seed=seed,
        input_ids=input_ids,
        segment_ids=segment_ids,
        input_mask=input_mask,
        average_sequence_length=average_sequence_length,
        random_sequence_length=random_sequence_length,
        mask_type=mask_type,
    )

    if len(all_inputs) != test_cases:
        print("Failed to create test data for test.")
    return all_inputs
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def get_graph_input_from_embed_node(onnx_model, embed_node, input_index):
    """Resolve the graph input feeding a given input slot of an embedding node.

    Looks up the node input directly as a graph input; when that fails, looks
    through a single Cast parent (graph inputs are often cast, e.g. int64 to
    int32, before reaching the EmbedLayerNormalization node).

    Args:
        onnx_model: model wrapper providing find_graph_input and get_parent
        embed_node: node whose input slot is being traced
        input_index: index into embed_node.input

    Returns:
        The matching graph input, or None when the slot does not exist or
        cannot be traced back to a graph input.
    """
    if input_index >= len(embed_node.input):
        return None

    candidate = onnx_model.find_graph_input(embed_node.input[input_index])
    if candidate is not None:
        return candidate

    # Not a direct graph input: check for a Cast node whose own input is one.
    producer = onnx_model.get_parent(embed_node, input_index)
    if producer is None or producer.op_type != "Cast":
        return None
    return onnx_model.find_graph_input(producer.input[0])
|
|
305
|
+
|
|
306
|
+
def find_bert_inputs(
    onnx_model: OnnxModel,
    input_ids_name: Optional[str] = None,
    segment_ids_name: Optional[str] = None,
    input_mask_name: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Find graph inputs for BERT model.

    First, we will deduce inputs from EmbedLayerNormalization node.
    If not found, we will guess the meaning of graph inputs based on naming.

    Args:
        onnx_model (OnnxModel): onnx model object
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Raises:
        ValueError: Graph does not have input named of input_ids_name or segment_ids_name or input_mask_name
        ValueError: Expected graph input number does not match with specified input_ids_name, segment_ids_name
                    and input_mask_name

    Returns:
        Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
                                                                                 segment_ids and input_mask
    """
    graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()

    if input_ids_name is not None:
        # Explicit names were given: validate each one against the graph.
        input_ids = onnx_model.find_graph_input(input_ids_name)
        if input_ids is None:
            raise ValueError(f"Graph does not have input named {input_ids_name}")

        segment_ids = onnx_model.find_graph_input(segment_ids_name) if segment_ids_name else None
        if segment_ids_name and segment_ids is None:
            raise ValueError(f"Graph does not have input named {segment_ids_name}")

        input_mask = onnx_model.find_graph_input(input_mask_name) if input_mask_name else None
        if input_mask_name and input_mask is None:
            raise ValueError(f"Graph does not have input named {input_mask_name}")

        expected_inputs = 1 + (1 if segment_ids else 0) + (1 if input_mask else 0)
        if len(graph_inputs) != expected_inputs:
            raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")

        return input_ids, segment_ids, input_mask

    if len(graph_inputs) != 3:
        raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")

    embed_nodes = onnx_model.get_nodes_by_op_type("EmbedLayerNormalization")
    if len(embed_nodes) == 1:
        embed_node = embed_nodes[0]
        input_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 0)
        segment_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 1)
        input_mask = get_graph_input_from_embed_node(onnx_model, embed_node, 7)

        if input_mask is None:
            # Mask slot not traceable from the embed node: fall back to name
            # matching over the graph inputs (last match wins).
            for candidate in graph_inputs:
                if "mask" in candidate.name.lower():
                    input_mask = candidate
            if input_mask is None:
                raise ValueError("Failed to find attention mask input")

        return input_ids, segment_ids, input_mask

    # Try guess the inputs based on naming.
    input_ids = None
    segment_ids = None
    input_mask = None
    for candidate in graph_inputs:
        lower_name = candidate.name.lower()
        if "mask" in lower_name:  # matches input with name like "attention_mask" or "input_mask"
            input_mask = candidate
        elif "token" in lower_name or "segment" in lower_name:
            # matches input with name like "segment_ids" or "token_type_ids"
            segment_ids = candidate
        else:
            input_ids = candidate

    if input_ids and segment_ids and input_mask:
        return input_ids, segment_ids, input_mask

    raise ValueError("Fail to assign 3 inputs. You might try rename the graph inputs.")
|
|
397
|
+
|
|
398
|
+
def get_bert_inputs(
    onnx_file: str,
    input_ids_name: Optional[str] = None,
    segment_ids_name: Optional[str] = None,
    input_mask_name: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Find graph inputs for BERT model.

    First, we will deduce inputs from EmbedLayerNormalization node.
    If not found, we will guess the meaning of graph inputs based on naming.

    Args:
        onnx_file (str): onnx model path
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Returns:
        Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
                                                                                 segment_ids and input_mask
    """
    # Deserialize the model file, then delegate input discovery.
    model = ModelProto()
    model.ParseFromString(Path(onnx_file).read_bytes())
    return find_bert_inputs(OnnxModel(model), input_ids_name, segment_ids_name, input_mask_name)
|
|
425
|
+
|
|
426
|
+
def parse_arguments():
    """Parse command line arguments for the BERT test data generator.

    Returns:
        argparse.Namespace: parsed arguments
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--model", required=True, type=str, help="bert onnx model path.")

    parser.add_argument(
        "--output_dir",
        required=False,
        type=str,
        default=None,
        help="output test data path. Default is current directory.",
    )

    parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input")

    parser.add_argument(
        "--sequence_length",
        required=False,
        type=int,
        default=128,
        help="maximum sequence length of input",
    )

    parser.add_argument(
        "--input_ids_name",
        required=False,
        type=str,
        default=None,
        help="input name for input ids",
    )
    parser.add_argument(
        "--segment_ids_name",
        required=False,
        type=str,
        default=None,
        help="input name for segment ids",
    )
    parser.add_argument(
        "--input_mask_name",
        required=False,
        type=str,
        default=None,
        help="input name for attention mask",
    )

    parser.add_argument(
        "--samples",
        required=False,
        type=int,
        default=1,
        help="number of test cases to be generated",
    )

    parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")

    # Note: action="store_true" already defaults to False, so no
    # parser.set_defaults(...) calls are needed for the boolean flags below.
    parser.add_argument(
        "--verbose",
        required=False,
        action="store_true",
        help="print verbose information",
    )

    parser.add_argument(
        "--only_input_tensors",
        required=False,
        action="store_true",
        help="only save input tensors and no output tensors",
    )

    parser.add_argument(
        "-a",
        "--average_sequence_length",
        default=-1,
        type=int,
        help="average sequence length excluding padding",
    )

    parser.add_argument(
        "-r",
        "--random_sequence_length",
        required=False,
        action="store_true",
        help="use uniform random instead of fixed sequence length",
    )

    parser.add_argument(
        "--mask_type",
        required=False,
        type=int,
        default=2,
        help="mask type: (1: mask index, 2: raw 2D mask, 3: key lengths, cumulated lengths of query and key)",
    )

    return parser.parse_args()
|
|
525
|
+
|
|
526
|
+
def create_and_save_test_data(
    model: str,
    output_dir: str,
    batch_size: int,
    sequence_length: int,
    test_cases: int,
    seed: int,
    verbose: bool,
    input_ids_name: Optional[str],
    segment_ids_name: Optional[str],
    input_mask_name: Optional[str],
    only_input_tensors: bool,
    average_sequence_length: int,
    random_sequence_length: bool,
    mask_type: int,
):
    """Create test data for a model, and save test data to a directory.

    Args:
        model (str): path of ONNX bert model
        output_dir (str): output directory
        batch_size (int): batch size
        sequence_length (int): sequence length
        test_cases (int): number of test cases
        seed (int): random seed
        verbose (bool): whether print more information
        input_ids_name (str): graph input name of input_ids
        segment_ids_name (str): graph input name of segment_ids
        input_mask_name (str): graph input name of input_mask
        only_input_tensors (bool): only save input tensors,
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type(int): mask type
    """
    input_ids, segment_ids, input_mask = get_bert_inputs(model, input_ids_name, segment_ids_name, input_mask_name)

    all_inputs = generate_test_data(
        batch_size,
        sequence_length,
        test_cases,
        seed,
        verbose,
        input_ids,
        segment_ids,
        input_mask,
        average_sequence_length,
        random_sequence_length,
        mask_type,
    )

    # Save the input tensors: one sub-directory per test case.
    for i, inputs in enumerate(all_inputs):
        directory = os.path.join(output_dir, "test_data_set_" + str(i))
        output_test_data(directory, inputs)

    if only_input_tensors:
        return

    # Run the model to produce reference outputs for each test case.
    # Imported lazily so that generating inputs alone does not require onnxruntime.
    import onnxruntime

    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if "CUDAExecutionProvider" in onnxruntime.get_available_providers()
        else ["CPUExecutionProvider"]
    )
    session = onnxruntime.InferenceSession(model, providers=providers)
    output_names = [output.name for output in session.get_outputs()]

    for i, inputs in enumerate(all_inputs):
        directory = os.path.join(output_dir, "test_data_set_" + str(i))
        result = session.run(output_names, inputs)
        # Use a distinct name for the inner index to avoid shadowing the
        # test-case index `i` of the outer loop.
        for output_index, output_name in enumerate(output_names):
            tensor_result = numpy_helper.from_array(np.asarray(result[output_index]), output_name)
            with open(os.path.join(directory, f"output_{output_index}.pb"), "wb") as file:
                file.write(tensor_result.SerializeToString())
|
|
601
|
+
|
|
602
|
+
def main():
    """Entry point: parse arguments, prepare the output directory, and generate test data."""
    args = parse_arguments()

    # A non-positive average means "no padding": use the full sequence length.
    if args.average_sequence_length <= 0:
        args.average_sequence_length = args.sequence_length

    output_dir = args.output_dir
    if output_dir is None:
        # Default output directory is a sub-directory under the directory of model.
        model_path = Path(args.model)
        output_dir = os.path.join(model_path.parent, f"batch_{args.batch_size}_seq_{args.sequence_length}")

    # Create the output directory if it does not exist. When it already exists,
    # test data files in it will be overwritten. (The previous `else` branch was
    # unreachable: output_dir is always set by this point.)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    create_and_save_test_data(
        args.model,
        output_dir,
        args.batch_size,
        args.sequence_length,
        args.samples,
        args.seed,
        args.verbose,
        args.input_ids_name,
        args.segment_ids_name,
        args.input_mask_name,
        args.only_input_tensors,
        args.average_sequence_length,
        args.random_sequence_length,
        args.mask_type,
    )

    print("Test data is saved to directory:", output_dir)
|
|
640
|
+
|
|
641
|
+
# Entry point guard: run the generator only when executed as a script.
if __name__ == "__main__":
    main()
|