onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
import datetime
|
|
8
|
+
import json
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BaseObject:
    """Base class for serializable benchmark-record components.

    Subclasses keep their regular fields as instance attributes; ad-hoc
    extra fields may be placed in ``self.customized`` and are merged over
    the declared attributes during serialization.
    """

    def __init__(self):
        # Bag of user-supplied extra fields, merged in by to_dict().
        self.customized = {}

    def to_dict(self):
        """Serialize to a plain dict, recursing into nested BaseObject values.

        Falsy entries (None, 0, False, empty containers) are dropped so that
        only meaningful fields appear in the output.
        """
        merged = {k: v for k, v in self.__dict__.items() if k != "customized"}
        merged.update(self.customized)

        result = {}
        for key, value in merged.items():
            if isinstance(value, BaseObject):
                value = value.to_dict()
            if value:
                result[key] = value
        return result
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ModelInfo(BaseObject):
    """Identity and task information for the model under benchmark.

    Args:
        full_name: Full model identifier (e.g. a Hugging Face repo name).
        is_huggingface: True when the model comes from the Hugging Face hub.
        is_text_generation: True when the model performs text generation.
        short_name: Optional abbreviated display name.
    """

    def __init__(
        self,
        full_name: Optional[str] = None,
        # Fixed misleading annotations: these flags default to False and are
        # never None, so they are plain bools, not Optional[bool].
        is_huggingface: bool = False,
        is_text_generation: bool = False,
        short_name: Optional[str] = None,
    ):
        super().__init__()
        self.full_name = full_name
        self.is_huggingface = is_huggingface
        self.is_text_generation = is_text_generation
        self.short_name = short_name
        # Populated later by the caller with the model's input dimensions.
        self.input_shape = []
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class BackendOptions(BaseObject):
    """Backend-specific execution options recorded with a benchmark run.

    Args:
        enable_profiling: Whether ORT profiling was enabled.
        execution_provider: Name of the execution provider used, if any.
        use_io_binding: Whether I/O binding was used.
    """

    def __init__(
        self,
        # Fixed misleading annotations: the flags default to False and are
        # never None, so they are plain bools, not Optional[bool].
        enable_profiling: bool = False,
        execution_provider: Optional[str] = None,
        use_io_binding: bool = False,
    ):
        super().__init__()
        self.enable_profiling = enable_profiling
        self.execution_provider = execution_provider
        self.use_io_binding = use_io_binding
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class Config(BaseObject):
    """Benchmark-run configuration: backend, shapes, precision, run counts.

    The nested ``model_info`` and ``backend_options`` objects start with
    defaults and are filled in by the caller (see BenchmarkRecord).

    Args:
        backend: Inference backend name (default "onnxruntime").
        batch_size: Batch size used for the run.
        seq_length: Sequence length used for the run (0 when not applicable).
        precision: Numeric precision label, e.g. "fp32" or "fp16".
        warmup_runs: Number of untimed warm-up iterations.
        measured_runs: Number of timed iterations.
    """

    def __init__(
        self,
        # Fixed misleading annotations: every parameter has a concrete
        # non-None default and is never None, so Optional[...] was wrong.
        backend: str = "onnxruntime",
        batch_size: int = 1,
        seq_length: int = 0,
        precision: str = "fp32",
        warmup_runs: int = 1,
        measured_runs: int = 10,
    ):
        super().__init__()
        self.backend = backend
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.precision = precision
        self.warmup_runs = warmup_runs
        self.measured_runs = measured_runs
        self.model_info = ModelInfo()
        self.backend_options = BackendOptions()
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class Metadata(BaseObject):
    """Execution-environment details recorded alongside each benchmark run.

    All fields are optional strings: the device name, the package that
    produced the numbers (name and version), the OS platform string, and
    the Python interpreter version.
    """

    def __init__(
        self,
        device: Optional[str] = None,
        package_name: Optional[str] = None,
        package_version: Optional[str] = None,
        platform: Optional[str] = None,
        python_version: Optional[str] = None,
    ):
        super().__init__()
        # Store each environment field verbatim on the instance.
        for field, value in (
            ("device", device),
            ("package_name", package_name),
            ("package_version", package_version),
            ("platform", platform),
            ("python_version", python_version),
        ):
            setattr(self, field, value)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class Metrics(BaseObject):
    """Measured performance numbers for one benchmark run.

    Args:
        latency_ms_mean: Mean per-inference latency in milliseconds.
        throughput_qps: Throughput in queries per second.
        max_memory_usage_GB: Peak memory usage in gigabytes.
    """

    def __init__(
        self,
        # Fixed misleading annotations: every metric defaults to 0.0 and is
        # never None, so the values are plain floats, not Optional[float].
        latency_ms_mean: float = 0.0,
        throughput_qps: float = 0.0,
        max_memory_usage_GB: float = 0.0,
    ):
        super().__init__()
        self.latency_ms_mean = latency_ms_mean
        self.throughput_qps = throughput_qps
        self.max_memory_usage_GB = max_memory_usage_GB
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class BenchmarkRecord:
    """One benchmark result: configuration, environment metadata, and metrics.

    The constructor distributes its arguments into the nested Config and
    Metadata objects; the Metrics object is left at its defaults for the
    caller to fill in after the run completes.
    """

    def __init__(
        self,
        model_name: str,
        precision: str,
        backend: str,
        device: str,
        package_name: str,
        package_version: str,
        batch_size: Optional[int] = 1,
        warmup_runs: Optional[int] = 1,
        measured_runs: Optional[int] = 10,
        trigger_date: Optional[str] = None,
    ):
        # Default to the current wall-clock time when no trigger date is given.
        if not trigger_date:
            trigger_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.trigger_date = trigger_date

        config = Config()
        config.model_info.full_name = model_name
        config.precision = precision
        config.backend = backend
        config.batch_size = batch_size
        config.warmup_runs = warmup_runs
        config.measured_runs = measured_runs
        self.config = config

        metadata = Metadata()
        metadata.device = device
        metadata.package_name = package_name
        metadata.package_version = package_version
        self.metadata = metadata

        # Filled in by the caller after the run completes.
        self.metrics = Metrics()

    def to_dict(self) -> dict:
        """Serialize the record (nested objects included) to a plain dict."""
        return {
            "config": self.config.to_dict(),
            "metadata": self.metadata.to_dict(),
            "metrics": self.metrics.to_dict(),
            "trigger_date": self.trigger_date,
        }

    def to_json(self) -> str:
        """Serialize the record to a JSON string (non-JSON values stringified)."""
        return json.dumps(self.to_dict(), default=str)

    @classmethod
    def save_as_csv(cls, file_name: str, records: list) -> None:
        """Write records to a CSV file, one flattened row each; no-op if empty."""
        if not records:
            return
        rows = [record.to_dict() for record in records]
        pd.json_normalize(rows).to_csv(file_name, index=False)

    @classmethod
    def save_as_json(cls, file_name: str, records: list) -> None:
        """Write records to a JSON file as a pretty-printed array; no-op if empty."""
        if not records:
            return
        rows = [record.to_dict() for record in records]
        with open(file_name, "w") as out:
            json.dump(rows, out, indent=4, default=str)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
import os.path
import sys

# Make sibling modules in this directory importable. Guard the append (the
# original appended unconditionally, duplicating the entry on re-import,
# while the transformers_dir append below was already guarded).
package_dir = os.path.dirname(__file__)
if package_dir not in sys.path:
    sys.path.append(package_dir)

# Also expose the transformers root (two levels up) so shared helpers resolve.
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
|
|
12
|
+
from utils import (
|
|
13
|
+
chain_enc_dec_with_beamsearch,
|
|
14
|
+
export_summarization_edinit,
|
|
15
|
+
export_summarization_enc_dec_past,
|
|
16
|
+
onnx_inference,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# GLOBAL ENVS
# Configure root logging once for this export script: timestamped,
# level-tagged lines written to stdout. The level is taken from the
# LOGLEVEL environment variable, defaulting to INFO.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
# Module-level logger used throughout this script.
logger = logging.getLogger("generate")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def print_args(args):
    """Log every parsed command-line argument as 'name: value'."""
    for name, value in vars(args).items():
        logger.info(f"{name}: {value}")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _str2bool(value):
    """Parse a command-line boolean value.

    Accepts true/false, yes/no, t/f, y/n, 1/0 (case-insensitive). Needed
    because ``type=bool`` treats any non-empty string — including the
    literal string "False" — as True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def user_command():
    """Define the export CLI, parse sys.argv once, log the values, and return the namespace."""
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser.add_argument("--max_length", type=int, default=20, help="default to 20")
    parent_parser.add_argument("--min_length", type=int, default=0, help="default to 0")
    parent_parser.add_argument("-o", "--output", type=str, default="onnx_models", help="default name is onnx_models.")
    parent_parser.add_argument("-i", "--input_text", type=str, default=None, help="input text")
    parent_parser.add_argument("-s", "--spm_path", type=str, default=None, help="tokenizer model from sentencepice")
    parent_parser.add_argument("-v", "--vocab_path", type=str, help="vocab dictionary")
    parent_parser.add_argument("-b", "--num_beams", type=int, default=5, help="default to 5")
    parent_parser.add_argument("--repetition_penalty", type=float, default=1.0, help="default to 1.0")
    parent_parser.add_argument("--no_repeat_ngram_size", type=int, default=3, help="default to 3")
    # Was `type=bool`, under which "--early_stopping False" parsed as True.
    parent_parser.add_argument("--early_stopping", type=_str2bool, default=False, help="default to False")
    parent_parser.add_argument("--opset_version", type=int, default=14, help="minimum is 14")

    # Flags to skip individual pipeline stages.
    parent_parser.add_argument("--no_encoder", action="store_true")
    parent_parser.add_argument("--no_decoder", action="store_true")
    parent_parser.add_argument("--no_chain", action="store_true")
    parent_parser.add_argument("--no_inference", action="store_true")

    required_args = parent_parser.add_argument_group("required input arguments")
    required_args.add_argument(
        "-m",
        "--model_dir",
        type=str,
        required=True,
        help="The directory contains input huggingface model. \
            An official model like facebook/bart-base is also acceptable.",
    )

    # Parse exactly once (the original called parse_args() twice, repeating
    # all parsing work and side effects).
    args = parent_parser.parse_args()
    print_args(args)
    return args
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
if __name__ == "__main__":
    args = user_command()
    # The beam-search chaining requires opset 14 or newer.
    if args.opset_version < 14:
        raise ValueError(f"The minimum supported opset version is 14! The given one was {args.opset_version}.")

    # Create the output directory if needed; exist_ok avoids the
    # check-then-create race of the original exists()/makedirs() pair.
    os.makedirs(args.output, exist_ok=True)

    # beam search op only supports CPU for now
    args.device = "cpu"
    logger.info("ENV: CPU ...")

    # Fall back to a canned summarization prompt when none was supplied.
    if not args.input_text:
        args.input_text = (
            "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
            "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
            "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
        )

    # Run each pipeline stage unless explicitly disabled via --no_* flags.
    if not args.no_encoder:
        logger.info("========== EXPORTING ENCODER ==========")
        export_summarization_edinit.export_encoder(args)
    if not args.no_decoder:
        logger.info("========== EXPORTING DECODER ==========")
        export_summarization_enc_dec_past.export_decoder(args)
    if not args.no_chain:
        logger.info("========== CONVERTING MODELS ==========")
        chain_enc_dec_with_beamsearch.convert_model(args)
    if not args.no_inference:
        logger.info("========== INFERENCING WITH ONNX MODEL ==========")
        onnx_inference.run_inference(args)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
import os.path
import sys

# Make sibling modules in this directory importable. Guard the append (the
# original appended unconditionally, duplicating the entry on re-import,
# while the transformers_dir append below was already guarded).
package_dir = os.path.dirname(__file__)
if package_dir not in sys.path:
    sys.path.append(package_dir)

# Also expose the transformers root (two levels up) so shared helpers resolve.
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)
|
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
#
|
|
6
|
+
# This script evaluates accuracy of ONNX models for question-answering task on SQuAD data set.
|
|
7
|
+
# Example to evaluate raw and optimized model for CUDA in Linux:
|
|
8
|
+
# pip3 install datasets evaluate optimum transformers onnxruntime-gpu
|
|
9
|
+
#
|
|
10
|
+
# python3 eval_squad.py -m bert-large-uncased-whole-word-masking-finetuned-squad -s 384 -b 1 --use_io_binding
|
|
11
|
+
#
|
|
12
|
+
# python3 -m onnxruntime.transformers.optimizer \
|
|
13
|
+
# --input ./bert-large-uncased-whole-word-masking-finetuned-squad/model.onnx \
|
|
14
|
+
# --output ./bert-large-uncased-whole-word-masking-finetuned-squad/optimized_model.onnx
|
|
15
|
+
#
|
|
16
|
+
# python3 eval_squad.py -m bert-large-uncased-whole-word-masking-finetuned-squad -s 384 -b 1 --use_io_binding \
|
|
17
|
+
# --onnx ./bert-large-uncased-whole-word-masking-finetuned-squad/optimized_model.onnx
|
|
18
|
+
#
|
|
19
|
+
# Snippet of example output in A100:
|
|
20
|
+
# {'exact': 86.65089877010406, 'f1': 92.99433524952254, 'total': 10570, 'HasAns_exact': 86.65089877010406
|
|
21
|
+
# 'total_time_in_seconds': 81.69239814393222, 'samples_per_second': 129.387804008115,
|
|
22
|
+
# 'latency_in_seconds': 0.007728703703304846, 'provider': 'CUDAExecutionProvider',
|
|
23
|
+
# 'pretrained_model_name': 'bert-large-uncased-whole-word-masking-finetuned-squad',
|
|
24
|
+
# 'batch_size': 1, 'sequence_length': 384, 'use_io_binding': True}
|
|
25
|
+
import argparse
|
|
26
|
+
import csv
|
|
27
|
+
import os
|
|
28
|
+
import time
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
32
|
+
except ImportError:
|
|
33
|
+
from importlib_metadata import PackageNotFoundError, version
|
|
34
|
+
|
|
35
|
+
from pathlib import Path
|
|
36
|
+
from typing import Any, Dict, List, Optional
|
|
37
|
+
|
|
38
|
+
from datasets import load_dataset
|
|
39
|
+
from evaluate import evaluator
|
|
40
|
+
from optimum.onnxruntime import ORTModelForQuestionAnswering
|
|
41
|
+
from optimum.version import __version__ as optimum_version
|
|
42
|
+
from packaging import version as version_check
|
|
43
|
+
from transformers import AutoTokenizer, pipeline
|
|
44
|
+
|
|
45
|
+
# The ORTModelForQuestionAnswering export path used below requires
# optimum >= 1.13.1; fail fast with a clear message otherwise.
if version_check.parse(optimum_version) < version_check.parse("1.13.1"):
    raise ImportError(f"Please install optimum>=1.13.1. Current version: {optimum_version}.")
|
|
47
|
+
|
|
48
|
+
PRETRAINED_SQUAD_MODELS = [
|
|
49
|
+
"bert-large-uncased-whole-word-masking-finetuned-squad",
|
|
50
|
+
"deepset/roberta-base-squad2",
|
|
51
|
+
"distilbert-base-cased-distilled-squad",
|
|
52
|
+
]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_package_version(package_name: str):
    """Return the installed version string of *package_name*, or None when the
    distribution is not installed."""
    try:
        found = version(package_name)
    except PackageNotFoundError:
        found = None
    return found
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def load_onnx_model(
    model_id: str, onnx_path: Optional[str] = None, provider="CUDAExecutionProvider", use_io_binding: bool = False
):
    """Load onnx model given pretrained model name and optional ONNX model path. If onnx_path is None,
    the default onnx model from optimum will be used.

    Args:
        model_id (str): pretrained model name or checkpoint path
        onnx_path (Optional[str], optional): path of onnx model to evaluate. Defaults to None.

    Returns:
        model: ORTModel for the onnx model
        onnx_path: the path of onnx model
    """
    # Existing model given: load it from its directory by file name.
    if onnx_path is not None:
        model = ORTModelForQuestionAnswering.from_pretrained(
            os.path.dirname(onnx_path),
            file_name=Path(onnx_path).name,
            provider=provider,
            use_io_binding=use_io_binding,
        )
        return model, onnx_path

    # No path given: export through optimum into a sub-directory named by
    # the model id, and report where the onnx file ended up.
    model = ORTModelForQuestionAnswering.from_pretrained(
        model_id, export=True, provider=provider, use_io_binding=use_io_binding
    )
    export_dir = os.path.join(".", model_id)
    model.save_pretrained(export_dir)
    exported_path = os.path.join(export_dir, "model.onnx")
    print("Model is exported to onnx file:", exported_path)
    return model, exported_path
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def output_details(results: List[Dict[str, Any]], csv_filename: str):
    """Output a CSV file with detail of each test results.

    Args:
        results (List[Dict[str, Any]]): list of JSON results.
        csv_filename (str): path of output CSV file
    """
    # Fixed column order: run configuration first, then metrics.
    column_names = [
        "pretrained_model_name",
        "onnx_path",
        "provider",
        "disable_fused_attention",
        "batch_size",
        "sequence_length",
        "use_io_binding",
        "exact",
        "f1",
        "total",
        "HasAns_exact",
        "HasAns_f1",
        "HasAns_total",
        "best_exact",
        "best_exact_thresh",
        "best_f1",
        "best_f1_thresh",
        "total_time_in_seconds",
        "samples_per_second",
        "latency_in_seconds",
    ]

    # Append mode so repeated runs accumulate in one file (each run re-writes
    # the header row, matching the original behavior).
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=column_names)
        writer.writeheader()
        writer.writerows(results)
        csv_file.flush()

    print(f"Detail results are saved to csv file: {csv_filename}")
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def output_summary(results: List[Dict[str, Any]], csv_filename: str, metric_name: str):
    """Output a CSV file with summary of a metric on combinations of batch_size and sequence_length.

    Appends one row per distinct onnx model: fixed descriptive columns first,
    then one "b{batch}_s{seq}" column per (batch_size, sequence_length) pair
    holding the metric value.

    Args:
        results (List[Dict[str, Any]]): list of JSON results.
        csv_filename (str): path of output CSV file
        metric_name (str): the metric to summarize
    """
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        # Descriptive columns copied from the first matching result row.
        header_names = [
            "pretrained_model_name",
            "onnx_path",
            "provider",
            "disable_fused_attention",
            "use_io_binding",
        ]

        model_list = list({result["onnx_path"] for result in results})
        model_list.sort()

        batch_sizes = list({result["batch_size"] for result in results})
        batch_sizes.sort()

        sequence_lengths = list({result["sequence_length"] for result in results})
        sequence_lengths.sort()

        # One metric column for every (batch_size, sequence_length) combination.
        key_names = []
        for sequence_length in sequence_lengths:
            for batch_size in batch_sizes:
                key_names.append(f"b{batch_size}_s{sequence_length}")

        csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + key_names)
        csv_writer.writeheader()

        for model in model_list:
            row = {}

            # Metric value for given pair of batch_size and sequence_length.
            # Assume that (onnx_path, batch_size and sequence_length) are unique so keep first occurrence only.
            values = {}
            values.update({k: "" for k in key_names})

            for result in results:
                # NOTE(review): a falsy metric value (e.g. a genuine 0 score)
                # is skipped here and would be left blank in the summary.
                if result["onnx_path"] == model and result[metric_name]:
                    headers = {k: v for k, v in result.items() if k in header_names}
                    if not row:
                        row.update(headers)

                    batch_size = result["batch_size"]
                    sequence_length = result["sequence_length"]
                    key = f"b{batch_size}_s{sequence_length}"

                    if key in key_names:
                        values[key] = result[metric_name]

            # Only emit a row when at least one result matched this model.
            if row:
                for key in key_names:
                    row[key] = values.get(key, "")
                csv_writer.writerow(row)

        csv_file.flush()

    print(f"Summary results for {metric_name} are saved to csv file: {csv_filename}")
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def main():
    """Evaluate ONNX question-answering models on SQuAD and write CSV reports.

    For each requested sequence length: load (or export via optimum) the ONNX
    model, build a transformers question-answering pipeline backed by ONNX
    Runtime, run the HuggingFace evaluator on the SQuAD validation split, and
    collect metrics. Results go to detail.csv plus one summary CSV per metric.

    Raises:
        RuntimeError: when --onnx points to a missing file, or a requested
            sequence length exceeds the model's max_position_embeddings.
    """
    args = parse_arguments()
    print(args)

    # Print versions of relevant packages to aid reproducibility.
    for name in ["onnxruntime-gpu", "onnxruntime", "onnx", "torch", "transformers", "optimum", "datasets", "evaluate"]:
        package_version = get_package_version(name)
        if package_version:
            print(f"{name} version", package_version)

    pretrained_model_name = args.model_name
    if args.onnx and not os.path.exists(args.onnx):
        raise RuntimeError(f"Onnx model path does not exist: {args.onnx}")

    # Environment toggle for ORT attention fusion; recorded in the results.
    disable_fused_attention = os.environ.get("ORT_DISABLE_FUSED_ATTENTION", "0") == "1"

    all_results = []
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
    for sequence_length in args.sequence_lengths:
        tokenizer.model_max_length = sequence_length
        # Stride for splitting long documents, capped at 128 tokens.
        tokenizer.doc_stride = min(sequence_length // 2, 128)
        if args.onnx is None:
            print("Exporting onnx model. It might take a few minutes...")
        start_time = time.time()
        ort_model, onnx_path = load_onnx_model(pretrained_model_name, args.onnx, args.provider, args.use_io_binding)
        latency = time.time() - start_time
        print(f"Onnx model exported or loaded in {latency:.1f} seconds")

        print(ort_model.config)
        if sequence_length > ort_model.config.max_position_embeddings:
            # Bug fix: the original string was missing the f-prefix, so the
            # placeholder text was printed literally instead of the limit.
            raise RuntimeError(
                f"sequence length should not be larger than {ort_model.config.max_position_embeddings}"
            )

        qa_pipeline = pipeline(
            "question-answering", model=ort_model, tokenizer=tokenizer, question_first=True, batch_size=args.batch_size
        )

        task_evaluator = evaluator("question-answering")
        print("Loading dataset...")
        start_time = time.time()
        # --total N evaluates only the first N validation samples; 0 means all.
        squad_dataset = load_dataset("squad", split=f"validation[:{args.total}]" if args.total > 0 else "validation")
        latency = time.time() - start_time
        print(f"Dataset loaded in {latency:.1f} seconds")

        print("Evaluating squad_v2 with ORT. It might take a few minutes...")
        start_time = time.time()
        result = task_evaluator.compute(
            model_or_pipeline=qa_pipeline,
            data=squad_dataset,
            metric="squad_v2",
            squad_v2_format=True,
        )
        latency = time.time() - start_time
        print(f"Evaluation done in {latency:.1f} seconds")

        # Attach run configuration so each CSV row is self-describing.
        result["provider"] = args.provider
        result["disable_fused_attention"] = disable_fused_attention
        result["pretrained_model_name"] = pretrained_model_name
        result["onnx_path"] = onnx_path
        result["batch_size"] = args.batch_size
        result["sequence_length"] = sequence_length
        result["use_io_binding"] = args.use_io_binding
        print(result)

        all_results.append(result)

    output_details(all_results, "detail.csv")

    for metric_name in ["f1", "exact", "samples_per_second"]:
        output_summary(all_results, f"{metric_name}.csv", metric_name)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def parse_arguments(argv=None):
    """Build and parse the command-line options for the SQuAD evaluation script.

    Args:
        argv: optional list of argument strings; defaults to sys.argv[1:].

    Returns:
        argparse.Namespace with the parsed options.
    """
    arg_parser = argparse.ArgumentParser()

    arg_parser.add_argument(
        "-m",
        "--model_name",
        required=False,
        type=str,
        default=PRETRAINED_SQUAD_MODELS[0],
        help=f"Checkpoint directory or pre-trained model names in the list: {PRETRAINED_SQUAD_MODELS}",
    )

    arg_parser.add_argument(
        "-s",
        "--sequence_lengths",
        nargs="+",
        type=int,
        default=[384],
        help="Sequence lengths for onnx model inputs. It could have multiple values.",
    )

    arg_parser.add_argument("-b", "--batch_size", type=int, default=1, help="batch size for inference.")

    arg_parser.add_argument("-t", "--total", type=int, default=0, help="Total samples to test. 0 means all samples.")

    arg_parser.add_argument(
        "--onnx",
        required=False,
        type=str,
        default=None,
        help="Optional onnx model path. If not specified, optimum will be used to export onnx model for testing.",
    )

    arg_parser.add_argument(
        "--provider",
        required=False,
        default="CUDAExecutionProvider",
        help="Select which Execution Provider to use for runs. Default is CUDAExecutionProvider.",
    )

    arg_parser.add_argument("--use_io_binding", required=False, action="store_true", help="Use IO Binding for GPU.")
    arg_parser.set_defaults(use_io_binding=False)

    return arg_parser.parse_args(argv)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
# Script entry point: run the full SQuAD evaluation flow defined in main().
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys

# Make sibling modules in this directory importable when scripts here are
# run directly (not as part of the installed package).
sys.path.append(os.path.dirname(__file__))

# Also expose the transformers tools directory two levels up, guarding
# against duplicate sys.path entries on repeated imports.
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)
|