onnxruntime-directml 1.24.1 (onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
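The capi files above bundle DirectML.dll together with the DirectML-enabled onnxruntime binaries, so inference sessions created from this wheel can target the DirectML execution provider. A minimal sketch, assuming a placeholder model path ("model.onnx" below is not a file shipped in this package); "DmlExecutionProvider" is the provider name registered by the DirectML build:

```python
import onnxruntime as ort

# The DirectML wheel is expected to report "DmlExecutionProvider"
# alongside "CPUExecutionProvider".
print(ort.get_available_providers())

# Create a session that prefers DirectML and falls back to CPU.
session = ort.InferenceSession(
    "model.onnx",  # placeholder path, not part of this package
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)
```

The single file diff reproduced below appears to correspond to onnxruntime/transformers/models/llama/benchmark.py, the 700-line addition in the listing above.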
@@ -0,0 +1,700 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import argparse
+import datetime
+import gc
+import itertools
+import logging
+import os
+import sys
+import time
+
+import numpy as np
+import onnx
+import psutil
+import torch
+from benchmark_helper import measure_memory, setup_logger
+from dist_settings import get_rank, get_size
+from llama_inputs import (
+    add_io_bindings_as_ortvalues,
+    get_merged_sample_with_past_kv_inputs,
+    get_msft_sample_inputs,
+    get_sample_inputs,
+    get_sample_with_past_kv_inputs,
+    verify_ort_inputs,
+)
+from optimum.onnxruntime import ORTModelForCausalLM
+from torch.profiler import ProfilerActivity, profile, record_function
+from tqdm import trange
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+import onnxruntime as ort
+
+logger = logging.getLogger(__name__)
+
+
+# For determining whether the ONNX model can do both prompt generation and token generation or only one of the two
+def get_ort_model_inputs_len(args, model):
+    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+        return 0
+    if args.benchmark_type == "hf-ort":
+        try:
+            # New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268)
+            return len(model.inputs_names)
+        except Exception:
+            # Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54)
+            return len(model.decoder.input_names)
+    return len(model.get_inputs())
+
+
+def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
+    init_inputs, iter_inputs = None, None
+
+    # For past_present_share_buffer:
+    # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
+    # Set max_seq_len to config value for other models
+    max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings
+
+    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+        init_inputs = get_sample_inputs(
+            args.config,
+            args.target_device,
+            args.batch_size,
+            args.sequence_length,
+            return_dict=True,
+        )
+        iter_inputs = get_sample_with_past_kv_inputs(
+            args.config,
+            args.target_device,
+            args.batch_size,
+            args.sequence_length,
+            use_fp16=args.use_fp16,
+            return_dict=True,
+        )
+
+    elif args.benchmark_type in {"hf-ort"}:
+        if ort_model_inputs_len == 3:  # [input_ids, attention_mask, position_ids]
+            # Using split models in Optimum (e.g. created by Optimum export)
+            init_inputs = get_sample_inputs(
+                args.config,
+                args.target_device,
+                args.batch_size,
+                args.sequence_length,
+                return_dict=True,
+            )
+            iter_inputs = get_sample_with_past_kv_inputs(
+                args.config,
+                args.target_device,
+                args.batch_size,
+                args.sequence_length,
+                use_fp16=args.use_fp16,
+                return_dict=True,
+            )
+        else:
+            # Using merged model in Optimum (e.g. created by convert_to_onnx export)
+            init_inputs = get_merged_sample_with_past_kv_inputs(
+                args.config,
+                args.target_device,
+                args.batch_size,
+                seq_len=args.sequence_length,
+                past_seq_len=0,
+                max_seq_len=max_seq_len,
+                use_fp16=args.use_fp16,
+                use_buffer_share=args.use_buffer_share,
+                engine="pt",
+                return_dict=True,
+            )
+            iter_inputs = get_merged_sample_with_past_kv_inputs(
+                args.config,
+                args.target_device,
+                args.batch_size,
+                seq_len=1,
+                past_seq_len=args.sequence_length,
+                max_seq_len=max_seq_len,
+                use_fp16=args.use_fp16,
+                use_buffer_share=args.use_buffer_share,
+                engine="pt",
+                return_dict=True,
+            )
+
+    elif args.benchmark_type == "ort-convert-to-onnx":
+        # Microsoft export from convert_to_onnx
+        init_inputs = get_merged_sample_with_past_kv_inputs(
+            args.config,
+            args.target_device,
+            args.batch_size,
+            seq_len=args.sequence_length,
+            past_seq_len=0,
+            max_seq_len=max_seq_len,
+            use_fp16=args.use_fp16,
+            use_buffer_share=args.use_buffer_share,
+            engine="ort",
+            return_dict=True,
+            world_size=args.world_size,
+        )
+        iter_inputs = get_merged_sample_with_past_kv_inputs(
+            args.config,
+            args.target_device,
+            args.batch_size,
+            seq_len=1,
+            past_seq_len=args.sequence_length,
+            max_seq_len=max_seq_len,
+            use_fp16=args.use_fp16,
+            use_buffer_share=args.use_buffer_share,
+            engine="ort",
+            return_dict=True,
+            world_size=args.world_size,
+        )
+
+    elif args.benchmark_type == "ort-msft":
+        # Microsoft export from https://github.com/microsoft/Llama-2-Onnx
+        split_kv = ort_model_inputs_len > 5  # original inputs: [x, attn_mask, k_cache, v_cache, pos]
+
+        init_inputs = get_msft_sample_inputs(
+            args.config,
+            args.batch_size,
+            past_seq_len=0,
+            seq_len=args.sequence_length,
+            max_seq_len=max_seq_len,
+            use_fp16=args.use_fp16,
+            use_buffer_share=args.use_buffer_share,
+            split_kv=split_kv,
+        )
+        iter_inputs = get_msft_sample_inputs(
+            args.config,
+            args.batch_size,
+            past_seq_len=args.sequence_length,
+            seq_len=1,
+            max_seq_len=max_seq_len,
+            use_fp16=args.use_fp16,
+            use_buffer_share=args.use_buffer_share,
+            split_kv=split_kv,
+        )
+
+    else:
+        raise Exception("Unable to auto-detect inputs for provided model")
+
+    return init_inputs, iter_inputs
+
+
+def get_model(args: argparse.Namespace):
+    model, sess_options = None, None
+    start_time, end_time = None, None
+
+    # There are multiple sources that the model could come from:
+    # 1) Benchmark LLaMA-2 from unofficial source on Hugging Face
+    # 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token
+    # 3) Benchmark LLaMA-2 from local download of model
+    # 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx)
+    # 5) Benchmark LLaMA-2 from convert_to_onnx
+
+    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+        source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name
+        start_time = time.time()
+        model = AutoModelForCausalLM.from_pretrained(
+            source,
+            torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
+            use_auth_token=args.auth,
+            trust_remote_code=args.auth,
+            use_cache=True,
+            cache_dir=args.cache_dir,
+        ).to(args.target_device)
+        end_time = time.time()
+
+        if args.benchmark_type == "hf-pt-compile":
+            model = torch.compile(model)
+
+    elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}:
+        sess_options = ort.SessionOptions()
+        sess_options.enable_profiling = args.profile
+        if args.verbose:
+            sess_options.log_verbosity_level = 1
+            sess_options.log_severity_level = 1
+
+    else:
+        raise Exception(f"Cannot recognize {args.benchmark_type}")
+
+    if args.benchmark_type == "hf-ort":
+        # Optimum export or convert_to_onnx.py export
+        provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
+        provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
+
+        decoder_file_name = None
+        decoder_with_past_file_name = None
+        for filename in os.listdir(args.hf_ort_dir_path):
+            if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
+                continue
+            if "decoder_model" in filename or filename == "model.onnx":
+                decoder_file_name = filename
+            if "decoder_with_past_model" in filename:
+                decoder_with_past_file_name = filename
+            if "decoder_merged_model" in filename:
+                decoder_file_name = filename
+                decoder_with_past_file_name = filename
+
+        start_time = time.time()
+        model = ORTModelForCausalLM.from_pretrained(
+            args.hf_ort_dir_path,
+            decoder_file_name=decoder_file_name,
+            decoder_with_past_file_name=decoder_with_past_file_name,
+            use_auth_token=args.auth,
+            trust_remote_code=args.auth,
+            use_io_binding=True,  # Large perf gain even for cpu due to avoiding output copy.
+            use_merged=(True if decoder_file_name == "model.onnx" else None),
+            provider=provider,
+            provider_options=provider_options,
+            session_options=sess_options,
+        )
+        end_time = time.time()
+
+    if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
+        # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
+        logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}")
+        start_time = time.time()
+        model = ort.InferenceSession(
+            args.ort_model_path.format(args.rank),
+            sess_options,
+            providers=[args.execution_provider],
+        )
+        end_time = time.time()
+
+    logger.info(f"Loaded model in {end_time - start_time} s")
+    return model
+
+
+def time_fn(args, fn, inputs):
+    # Warm up
+    warmup_range = (
+        range(args.warmup_runs)
+        if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
+        else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
+    )
+
+    if args.verbose:
+        outputs = fn(inputs)
+        logger.info(outputs)
+
+    input_sync = lambda *kwargs: (  # noqa: E731
+        args.io_binding.synchronize_inputs()
+        if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}  # ORT synchronize
+        else lambda *kwargs: (
+            torch.cuda.synchronize()
+            if args.device != "cpu" and torch.cuda.is_available()  # PyTorch synchronize
+            else lambda *kwargs: None
+        )
+    )  # no-op function
+
+    output_sync = lambda *kwargs: (  # noqa: E731
+        args.io_binding.synchronize_outputs()
+        if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}  # ORT synchronize
+        else lambda *kwargs: (
+            torch.cuda.synchronize()
+            if args.device != "cpu" and torch.cuda.is_available()  # PyTorch synchronize
+            else lambda *kwargs: None
+        )
+    )  # no-op function
+
+    for _ in warmup_range:
+        input_sync()
+        fn(inputs)
+        output_sync()
+
+    # Benchmark
+    total_time = 0
+    bench_range = (
+        range(args.num_runs)
+        if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
+        else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
+    )
+    for _ in bench_range:
+        input_sync()
+        start_time = time.time()
+
+        fn(inputs)
+
+        output_sync()
+        end_time = time.time()
+
+        total_time += end_time - start_time
+
+    # Newline print after trange in order to print metrics on new lines without progress bar on same line
+    if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
+        logger.info("")
+
+    latency = total_time / args.num_runs
+    throughput = args.batch_size / latency
+
+    if args.rank == 0:
+        logger.info(f"Batch Size: {args.batch_size}")
+        logger.info(f"Sequence Length: {args.sequence_length}")
+        logger.info(f"Latency: {latency} s")
+        logger.info(f"Throughput: {throughput} tps")
+    return
+
+
+def profile_fn(args, fn, inputs, inputs_type):
+    # Filename prefix format:
+    # "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
+    prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
+    filename = None
+
+    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+        # Profile PyTorch kernels
+        with profile(  # noqa: SIM117
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
+        ) as prof:
+            with record_function("model_inference"):
+                fn(inputs)
+        prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
+
+        filename = os.path.join(args.log_folder, f"{prefix}.log")
+        with open(filename, "w") as f:
+            f.write(prof_data)
+
+    else:
+        # Profile ORT kernels
+        fn(inputs)
+
+        # Set new log name for ORT profile log generated
+        filename = f"{prefix}.json"
+
+    return filename
+
+
+def measure_fn(args, fn, inputs):
+    # Measure CPU usage
+    pid = os.getpid()
+    process = psutil.Process(pid)
+    process.cpu_percent(interval=0.1)
+
+    fn(inputs)
+    if args.rank == 0:
+        logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%")
+
+    # Measure memory usage
+    gc.collect()
+    torch.cuda.empty_cache()
+    measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
+
+    # Flush output so memory usage is printed
+    sys.stdout.flush()
+
+
+def run_hf_inference(args, init_inputs, iter_inputs, model):
+    # Inference steps to measure
+    def get_logits(inputs):
+        # Inference pass without decoding
+        outputs = model(**inputs)
+        return outputs
+
+    # Examples of other inference steps that can be measured:
+    # To use, uncomment the function and assign it to `generate_fn`
+
+    # def get_pred_ids(inputs):
+    #     # Inference pass with predicted token ids generation
+    #     predicted_ids = model.generate(**inputs)
+    #     return predicted_ids
+
+    # def gen_and_dec(inputs):
+    #     # Inference pass with generation and decoding
+    #     predicted_ids = get_pred_ids(inputs)
+    #     transcription = []
+    #     for bs in range(args.batch_size):
+    #         for rs in range(args.num_return_sequences):
+    #             transcription.append(
+    #                 args.tokenizer.batch_decode(
+    #                     predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
+    #                 )[0]
+    #             )
+    #     return transcription
+
+    generate_fn = get_logits
+
+    if args.benchmark_type == "hf-pt-compile":
+        # Run forward pass once with each set of inputs to process through Dynamo
+        generate_fn(init_inputs)
+        generate_fn(iter_inputs)
+
+    if args.profile:
+        new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
+        if args.benchmark_type == "hf-ort":
+            # Turn profiling off to stop appending to log
+            old_logname = model.decoder.session.end_profiling()
+            logger.warning(f"Renaming {old_logname} to {new_logname}")
+            os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+        new_logname = profile_fn(args, generate_fn, iter_inputs, "token")
+        if args.benchmark_type == "hf-ort":
+            # Turn profiling off to stop appending to log
+            old_logname = model.decoder_with_past.session.end_profiling()
+            logger.warning(f"Renaming {old_logname} to {new_logname}")
+            os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+        return
+
+    # PyTorch evaluations
+    logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
+    time_fn(args, generate_fn, init_inputs)
+    measure_fn(args, generate_fn, init_inputs)
+
+    logger.info("\nEvaluating `model(inputs)` step with past_key_values")
+    time_fn(args, generate_fn, iter_inputs)
+    measure_fn(args, generate_fn, iter_inputs)
+
+
+def run_ort_inference(args, init_inputs, iter_inputs, model):
+    def prepare_ort_inputs(inputs, kv_cache_ortvalues):
+        # Verify model inputs
+        inputs = verify_ort_inputs(model, inputs)
+
+        # Add IO bindings for non-CPU execution providers
+        if args.device != "cpu":
+            io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
+                model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues
+            )
+            setattr(args, "io_binding", io_binding)  # noqa: B010
+            return io_binding, kv_cache_ortvalues
+
+        return inputs, kv_cache_ortvalues
+
+    def with_io_binding(io_binding):
+        # Inference pass with IO binding
+        model.run_with_iobinding(io_binding)
+
+    def without_io_binding(inputs):
+        # Inference pass without IO binding
+        outputs = model.run(None, inputs)
+        return outputs
+
+    generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
+    kv_cache_ortvalues = {}
+
+    if args.profile:
+        ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
+        new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
+
+        # Turn profiling off to stop appending to log file
+        old_logname = model.end_profiling()
+        logger.warning(f"Renaming {old_logname} to {new_logname}")
+        os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+        # Re-initialize model for new log file instead of appending to old log file
+        model = get_model(args)
+        ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
+        new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
+
+        # Turn profiling off to stop appending to log
+        old_logname = model.end_profiling()
+        logger.warning(f"Renaming {old_logname} to {new_logname}")
+        os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+        return
+
+    # ORT evaluations
+    logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
+    ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
+    time_fn(args, generate_fn, ort_init_inputs)
+    measure_fn(args, generate_fn, ort_init_inputs)
+
+    logger.info("\nEvaluating `model(inputs)` step with past_key_values")
+    ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
+    time_fn(args, generate_fn, ort_iter_inputs)
+    measure_fn(args, generate_fn, ort_iter_inputs)
+
+
+def run_inference(args, init_inputs, iter_inputs, model):
+    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
+        run_hf_inference(args, init_inputs, iter_inputs, model)
+    elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
+        run_ort_inference(args, init_inputs, iter_inputs, model)
+    else:
+        raise Exception(f"Cannot recognize {args.benchmark_type}")
+
+
+def get_args(rank=0):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-bt",
+        "--benchmark-type",
+        type=str,
+        required=True,
+        choices=[
+            "hf-pt-eager",
+            "hf-pt-compile",
+            "hf-ort",
+            "ort-msft",
+            "ort-convert-to-onnx",
+        ],
+    )
+    parser.add_argument(
+        "-m",
+        "--model-name",
+        type=str,
+        required=True,
+        help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
+    )
+    parser.add_argument(
+        "-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
+    )
+
+    # Args for choosing the model
+    parser.add_argument(
+        "-p",
+        "--precision",
+        required=True,
+        type=str,
+        default="fp32",
+        choices=["int4", "int8", "fp16", "fp32"],
+        help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
+    )
+    parser.add_argument(
+        "--hf-pt-dir-path",
+        type=str,
+        default="",
+        help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
+    )
+    parser.add_argument(
+        "--hf-ort-dir-path",
+        type=str,
+        default="",
+        help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)",
+    )
+    parser.add_argument(
+        "--ort-model-path",
+        type=str,
+        default="",
+        help="Path to ONNX model",
+    )
+
+    # Args for running and evaluating the model
+    parser.add_argument(
+        "-b",
+        "--batch-sizes",
+        default="1 2",
+    )
+    parser.add_argument(
+        "-s",
+        "--sequence-lengths",
+        default="32 64 128 256 512",
+    )
+    parser.add_argument(
+        "-d",
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        choices=["cpu", "cuda"],
+    )
+    parser.add_argument("-id", "--device-id", type=int, default=0)
+    parser.add_argument("-w", "--warmup-runs", type=int, default=5)
+    parser.add_argument("-n", "--num-runs", type=int, default=10)
+    parser.add_argument("--seed", type=int, default=2)
+
+    # Args for decoding logic
+    parser.add_argument("--max-length", type=int, default=32)
+    parser.add_argument("--num-return-sequences", type=int, default=1)
+
+    # Args for accessing detailed info
+    parser.add_argument("--profile", default=False, action="store_true")
+    parser.add_argument(
+        "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
+    )
+    parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
+    parser.add_argument("--verbose", default=False, action="store_true")
+    parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
+    parser.add_argument(
+        "--cache-dir",
+        type=str,
+        required=True,
+        default="./model_cache",
+        help="Cache dir where Hugging Face files are stored",
+    )
+
+    args = parser.parse_args()
+
+    # Set seed properties
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+
+    # Set runtime properties
+    if "ort" in args.benchmark_type:
+        setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider")  # noqa: B010
+        if args.execution_provider == "CUDAExecutionProvider":
+            args.execution_provider = (args.execution_provider, {"device_id": rank})
+
+    # Check that paths have been specified for any benchmarking with ORT
+    if args.benchmark_type == "hf-ort":
+        assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
+    if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
+        assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
+
+    args.batch_sizes = args.batch_sizes.split(" ")
+    args.sequence_lengths = args.sequence_lengths.split(" ")
+
+    # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
+    args.precision = (
+        "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
+    )
+
+    # Check that only one (batch_size, sequence_length) combination is set for profiling
+    if args.profile:
+        assert len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1, (
+            "Please provide only one (batch_size, sequence_length) combination for profiling"
+        )
+
+    return args
+
+
+def main():
+    rank = get_rank()
+    world_size = get_size()
+
+    args = get_args(rank)
+    setup_logger(args.verbose)
+    logger.info(args.__dict__)
+    torch.backends.cudnn.benchmark = True
+
+    args.rank = rank
+    args.world_size = world_size
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
+    )
+    config = AutoConfig.from_pretrained(
+        args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
+    )
+    target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
+    use_fp16 = args.precision == "fp16"
+
+    setattr(args, "tokenizer", tokenizer)  # noqa: B010
+    setattr(args, "config", config)  # noqa: B010
+    setattr(args, "target_device", target_device)  # noqa: B010
+    setattr(args, "use_fp16", use_fp16)  # noqa: B010
+
+    # Get model and model info
+    model = get_model(args)
+    ort_model_inputs_len = get_ort_model_inputs_len(args, model)
+
+    # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
+    if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
+        onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
+        gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
+
+        use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
+        setattr(args, "use_buffer_share", use_buffer_share)  # noqa: B010
+    else:
+        setattr(args, "use_buffer_share", False)  # noqa: B010
+
+    # Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
+    for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
+        if args.rank == 0:
+            logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
+        setattr(args, "batch_size", int(batch_size))  # noqa: B010
+        setattr(args, "sequence_length", int(sequence_length))  # noqa: B010
+
+        init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len)
+        run_inference(args, init_inputs, iter_inputs, model)
+
+
+if __name__ == "__main__":
+    main()
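A note on the run_with_iobinding path used in run_ort_inference above: the add_io_bindings_as_ortvalues helper comes from the llama_inputs module listed earlier and is not part of this diff, but a rough sketch of the underlying onnxruntime IO-binding API, with illustrative input and output names that a given export may not actually use, looks like this:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
io_binding = session.io_binding()

# Bind a CPU-resident input; ORT copies it to the provider's device as needed.
# "input_ids" is an illustrative name, not necessarily what the exported model uses.
io_binding.bind_cpu_input("input_ids", np.zeros((1, 32), dtype=np.int64))

# Let ORT allocate the output buffer (on CPU by default) and fetch it afterwards.
io_binding.bind_output("logits")

session.run_with_iobinding(io_binding)
logits = io_binding.copy_outputs_to_cpu()[0]
```

Binding the past and present key-value tensors as OrtValues in this way is what lets the benchmark reuse device buffers across token-generation steps when use_buffer_share is enabled.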