onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/llama/benchmark_e2e.py
@@ -0,0 +1,606 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

# This is an end-to-end benchmarking script for the Hugging Face LLaMA-2 model.
#
# Prerequisites:
# 1) Install `huggingface-cli`:
#
# $ pip install huggingface_hub
#
# 2) Authenticate with Hugging Face's CLI:
#
# $ huggingface-cli login
#
# 3) Accept Meta's license in Hugging Face to access the models at https://huggingface.co/meta-llama/
#
# 4) Install the latest ONNX Runtime version
#
# $ pip install onnxruntime-gpu
#
# 5) Install flash attention v2
#
# $ pip install flash-attn --no-build-isolation
#
# 6) Install bitsandbytes
#
# $ pip install bitsandbytes
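#
# One way to invoke the script once the prerequisites are in place (run from the
# onnxruntime/transformers directory so that benchmark_helper and llama_inputs resolve;
# the model and ONNX paths below are placeholders, and --prompts-file must point to a
# JSON file of "prompt length": "prompt" entries):
#
# $ python -m models.llama.benchmark_e2e \
#       --benchmark-type ort \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --onnx-model-path ./llama2-7b-fp16/model.onnx \
#       --prompts-file ./models/llama/prompts.json \
#       --precision fp16 \
#       --device cuda \
#       --auth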

from __future__ import annotations

import argparse
import datetime
import gc
import itertools
import json
import logging
import os
import textwrap
import time

import numpy as np
import pandas as pd
import torch
from benchmark_helper import setup_logger
from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

import onnxruntime as ort

logger = logging.getLogger(__name__)


def get_model(args: argparse.Namespace):
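    # For the PyTorch benchmarks, load a Hugging Face model (4-bit via bitsandbytes when
    # --precision int4 on CUDA, otherwise flash attention 2 / SDPA with an eager fallback)
    # and optionally wrap it in torch.compile. For the ORT benchmark, create an
    # InferenceSession over the exported ONNX model instead.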
    if args.benchmark_type in {"pt-eager", "pt-compile"}:
        model = None
        if args.onnx_precision == "int4" and args.device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )

            model = AutoModelForCausalLM.from_pretrained(
                args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
                cache_dir=args.cache_dir,
                torch_dtype=args.torch_dtype,
                use_auth_token=args.auth,
                trust_remote_code=args.trust,
                use_cache=True,
                attn_implementation="flash_attention_2",
                quantization_config=bnb_config,
                max_memory={args.device_id: "80GB"},
            )
        else:
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
                    cache_dir=args.cache_dir,
                    torch_dtype=args.torch_dtype,
                    use_auth_token=args.auth,
                    trust_remote_code=args.trust,
                    use_cache=True,
                    attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
                ).to(args.target_device)
            except Exception as e:
                # When flash_attention or sdpa doesn't support a model, it throws an exception.
                # Rather than stopping a process, run as eager mode.
                print("Try to load a model using eager mode: ", e)
                model = AutoModelForCausalLM.from_pretrained(
                    args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
                    cache_dir=args.cache_dir,
                    torch_dtype=args.torch_dtype,
                    use_auth_token=args.auth,
                    trust_remote_code=args.trust,
                    use_cache=True,
                    attn_implementation="eager",
                ).to(args.target_device)

        model.eval()

        if args.benchmark_type == "pt-compile":
            model = torch.compile(model)

    else:
        sess_options = ort.SessionOptions()
        ep = (
            ("CUDAExecutionProvider", {"device_id": args.device_id})
            if args.device == "cuda"
            else "CPUExecutionProvider"
        )
        model = ort.InferenceSession(args.onnx_model_path, sess_options=sess_options, providers=[ep])

    return model


def run_inference(args, model, runs, inputs, outputs):
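    # Returns the average per-run latency in seconds over `runs` iterations, plus the outputs
    # of the last run. PyTorch runs synchronize CUDA so asynchronous kernel launches do not
    # skew the timings; ORT runs bind inputs/outputs once via IO binding and reuse the binding
    # across iterations.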
    if args.benchmark_type == "pt-compile":
        with torch.no_grad():
            outputs = model(**inputs)

    # Synchronize inputs
    io_binding = None
    if args.benchmark_type in {"pt-eager", "pt-compile"}:
        if args.device != "cpu":
            torch.cuda.synchronize(args.target_device)
    else:
        io_binding = add_io_bindings_as_tensors(model, inputs, outputs, args.use_fp16, args.use_buffer_share)
        io_binding.synchronize_inputs()

    # Run inference
    start = time.perf_counter()
    for _ in range(runs):
        if args.benchmark_type in {"pt-eager", "pt-compile"}:
            with torch.no_grad():
                outputs = model(**inputs)
                if args.device != "cpu":
                    torch.cuda.synchronize(args.target_device)
        else:
            model.run_with_iobinding(io_binding)
            io_binding.synchronize_outputs()

    end = time.perf_counter()
    avg = (end - start) / runs
    return avg, outputs


def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt):
    clear_cache()
    inputs, outputs = get_initial_inputs_and_outputs(
        config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine
    )
    _, outputs = run_inference(args, model, args.warmup_runs, inputs, outputs)
    return inputs, outputs


def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()


def save_results(results, filename, gen_length):
    df = pd.DataFrame(
        results,
        columns=[
            "Batch Size",
            "Prompt Length",
            "Prompt Processing Latency (ms)",
            "Prompt Processing Throughput (tps)",
            "Sampling Latency (ms)",
            "Sampling Throughput (tps)",
            "First Token Generated Latency (ms)",
            "First Token Generated Throughput (tps)",
            f"Average Latency of First {gen_length // 2} Tokens Generated (ms)",
            f"Average Throughput of First {gen_length // 2} Tokens Generated (tps)",
            f"Average Latency of First {gen_length} Tokens Generated (ms)",
            f"Average Throughput of First {gen_length} Tokens Generated (tps)",
            "Wall-Clock Latency (s)",
            "Wall-Clock Throughput (tps)",
        ],
    )

    df.to_csv(filename, index=False)
    logger.info(f"Results saved in {filename}!")


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-bt",
        "--benchmark-type",
        type=str,
        required=True,
        choices=["pt-eager", "pt-compile", "ort"],
    )

    parser.add_argument(
        "-m",
        "--model-name",
        type=str,
        required=False,
        help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
    )

    parser.add_argument(
        "-a",
        "--auth",
        default=False,
        action="store_true",
        help="Use Hugging Face authentication token to access model",
    )

    parser.add_argument(
        "-t",
        "--trust",
        default=False,
        action="store_true",
        help="Whether or not to allow for custom models defined on the Hugging Face Hub in their own modeling files",
    )

    parser.add_argument(
        "-c",
        "--cache-dir",
        type=str,
        default=os.path.join(".", "model_cache"),
        help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(model_name, cache_dir=cache_dir)`.",
    )

    parser.add_argument(
        "--hf-dir-path",
        type=str,
        default="",
        help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(folder_path)`.",
    )

    parser.add_argument(
        "-o",
        "--onnx-model-path",
        required=False,
        help="Path to ONNX model",
    )

    parser.add_argument(
        "-f",
        "--prompts-file",
        required=True,
        default=os.path.join(".", "models", "llama", "prompts.json"),
        help="JSON file containing entries in the format 'prompt length: prompt' where prompt length = tokenized length of prompt",
    )

    parser.add_argument(
        "--use_buffer_share",
        default=False,
        action="store_true",
        help="Use when GroupQueryAttention (GQA) is in ONNX model",
    )

    parser.add_argument(
        "--anomaly-filtering",
        default=False,
        action="store_true",
        help="Use this flag to filter anomaly accelerator times for tokens generated. \
              This may give more accurate latency and throughput metrics for tokens generated. \
              Wall-clock metrics are still reported with anomaly times though.",
    ),

    parser.add_argument(
        "-b",
        "--batch-sizes",
        default="1 2",
    )

    parser.add_argument(
        "-s",
        "--prompt-lengths",
        default="16 64 256 1024",
    )

    parser.add_argument(
        "-p",
        "--precision",
        required=True,
        type=str,
        default="fp32",
        choices=["int4", "int8", "fp16", "fp32"],
        help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
    )

    parser.add_argument(
        "-g",
        "--generation-length",
        type=int,
        default=256,
        help="Number of new tokens to generate",
    )

    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cpu", "cuda"],
    )

    parser.add_argument("-id", "--device-id", type=int, default=0)
    parser.add_argument("-w", "--warmup-runs", type=int, default=5)
    parser.add_argument("-n", "--num-runs", type=int, default=100)
    parser.add_argument("--seed", type=int, default=2)

    args = parser.parse_args()

    # Set seed properties
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Set runtime properties
    if "ort" in args.benchmark_type:
        setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider")  # noqa: B010
        if args.execution_provider == "CUDAExecutionProvider":
            args.execution_provider = (args.execution_provider, {"device_id": args.device_id})

    # Check that paths have been specified for any benchmarking with ORT
    if args.benchmark_type == "ort":
        assert args.onnx_model_path, "Please specify a path to `--onnx-model-path`"

    args.batch_sizes = args.batch_sizes.split(" ")
    args.prompt_lengths = args.prompt_lengths.split(" ")

    # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
    setattr(args, "onnx_precision", args.precision)  # noqa: B010
    args.precision = (
        "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
    )

    target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
    torch_dtype = torch.float16 if args.precision == "fp16" else torch.float32
    engine = "ort" if args.benchmark_type == "ort" else "pt"
    setattr(args, "target_device", target_device)  # noqa: B010
    setattr(args, "torch_dtype", torch_dtype)  # noqa: B010
    setattr(args, "engine", engine)  # noqa: B010
    setattr(args, "use_fp16", args.precision == "fp16")  # noqa: B010

    args.use_buffer_share = args.use_buffer_share and engine == "ort"
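    # (Buffer sharing is an ORT-only optimization for GQA models, so the flag is dropped
    # for PyTorch runs.)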

    return args


def main():
    args = get_args()
    setup_logger(False)
    logger.info(args.__dict__)

    # Get prompts and prompt sizes
    size_to_prompt = None
    with open(args.prompts_file) as f:
        size_to_prompt = json.load(f, object_hook=lambda d: {int(k): v for k, v in d.items()})

    # Get config, tokenizer, and model
    config = AutoConfig.from_pretrained(
        args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
        cache_dir=args.cache_dir,
        use_auth_token=args.auth,
        trust_remote_code=args.trust,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
        cache_dir=args.cache_dir,
        use_auth_token=args.auth,
        trust_remote_code=args.trust,
    )
    model = get_model(args)

    all_csv_metrics = []
    for batch_size, prompt_length in itertools.product(args.batch_sizes, args.prompt_lengths):
        batch_size, prompt_length = int(batch_size), int(prompt_length)  # noqa: PLW2901
        logger.info(f"Running batch size = {batch_size}, prompt length = {prompt_length}")
        clear_cache()
        max_length = prompt_length + args.generation_length

        if prompt_length not in size_to_prompt:
            raise NotImplementedError(
                textwrap.dedent(
                    f"""
                    A prompt of size {prompt_length} was not found in '{args.prompts_file}'. There are a couple of solutions to fix this.
                    1) You can change one of the keys in '{args.prompts_file}' to be {prompt_length}.
                       If {prompt_length} < actual prompt's length, the benchmark E2E tool will repeat the first word in the prompt until {prompt_length} = actual prompt's length.
                       If {prompt_length} > actual prompt's length, the benchmark E2E tool will automatically trim the actual prompt's length so that {prompt_length} = actual prompt's length.
                    2) You can add a new key-value entry in '{args.prompts_file}' of the form '{prompt_length}': 'your prompt goes here'.
                """
                )
            )
        prompt = [size_to_prompt[prompt_length]] * batch_size
        csv_metrics = [batch_size, prompt_length]

        try:
            # Measure prompt processing
            logger.info("Measuring prompt processing...")
            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
            accelerator_prompt_latency_s, outputs = run_inference(args, model, args.num_runs, inputs, outputs)

            # Calculate prompt metrics
            accelerator_prompt_latency_ms = accelerator_prompt_latency_s * 1000
            accelerator_prompt_thrpt = batch_size * (prompt_length / accelerator_prompt_latency_s)
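            # e.g. with batch_size = 2, prompt_length = 256, and an average prompt latency of
            # 0.5 s, throughput = 2 * (256 / 0.5) = 1024 tokens per second.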
            logger.info(f"Average Latency of Prompt Processing: {accelerator_prompt_latency_ms} ms")
            logger.info(
                f"Average Throughput of Prompt Processing: {batch_size * (prompt_length / accelerator_prompt_latency_s)} tps"
            )
            csv_metrics.extend([accelerator_prompt_latency_ms, accelerator_prompt_thrpt])

            # Measure token generation
            logger.info("Measuring token generation...")
            clear_cache()
            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)

            all_token_ids = inputs["input_ids"].clone()
            current_length = all_token_ids.shape[-1]
            num_heads = config.num_key_value_heads
            head_size = (
                config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
            )

            has_eos = torch.zeros(batch_size, device=args.target_device, dtype=torch.bool)

            # 0th entry will have prompt accelerator time, 1st entry onwards will have token generation accelerator time
            accelerator_times = []
            sampling_times = []  # cost to sample after each model run

            wall_clock_start_time = time.perf_counter()
            while current_length <= max_length:
                # Run inference
                accelerator_time_latency_s, outputs = run_inference(args, model, 1, inputs, outputs)
                accelerator_times.append(accelerator_time_latency_s)

                # Sample with argmax (greedy search)
                sampling_start_time = time.perf_counter()
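                # On the first step after prompt processing, logits has shape
                # (batch_size, prompt_length, vocab_size), so gather the logits at each
                # sequence's last non-padded position (attention_mask.sum(1) - 1). On later
                # steps the model returns logits for a single token.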
                if outputs["logits"].shape[1] > 1:
                    prompt_end_indices = inputs["attention_mask"].sum(1) - 1
                    idxs = (
                        prompt_end_indices.unsqueeze(dim=1)
                        .repeat(1, config.vocab_size)
                        .view(batch_size, 1, config.vocab_size)
                    )
                    next_token_logits = torch.gather(outputs["logits"], 1, idxs).squeeze()
                else:
                    next_token_logits = outputs["logits"][:, -1, :]
                next_tokens = torch.argmax(next_token_logits, dim=-1)

                # Check if we previously reached EOS token id or if generated token id is EOS token id
                has_eos = has_eos | next_tokens == tokenizer.eos_token_id

                # Determine which new tokens to add to list of all token ids
                # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
                tokens_to_add = next_tokens.masked_fill(has_eos, tokenizer.eos_token_id).reshape([batch_size, 1])
                sampling_end_time = time.perf_counter()
                sampling_times.append(sampling_end_time - sampling_start_time)

                all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
                current_length += 1

                # Update inputs for next inference run
                inputs["input_ids"] = tokens_to_add
                inputs["attention_mask"] = torch.cat(
                    [inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
                )
                if "position_ids" in inputs:
                    inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1

                # Set logits to zeros for next inference run and re-use memory buffer
                if outputs["logits"].shape[1] != 1:
                    outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
                outputs["logits"].zero_()

                # Update KV caches for next inference run
                if args.engine == "pt":
                    # Update KV caches for PyTorch
                    inputs["past_key_values"] = outputs["past_key_values"]
                elif not args.use_buffer_share:
                    # Update KV caches for ONNX Runtime if buffer sharing is not used
                    for i in range(config.num_hidden_layers):
                        inputs[f"past_key_values.{i}.key"] = outputs[f"present.{i}.key"]
                        inputs[f"past_key_values.{i}.value"] = outputs[f"present.{i}.value"]

                    new_sequence_length = inputs["attention_mask"].shape[1]
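                    # Without buffer sharing, freshly zeroed present KV buffers are allocated
                    # each step, sized for the grown sequence, so the next run has output
                    # buffers of the right shape to bind.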
                    for i in range(config.num_hidden_layers):
                        present_key = torch.zeros(
                            batch_size,
                            num_heads,
                            new_sequence_length,
                            head_size,
                            device=args.target_device,
                            dtype=args.torch_dtype,
                        )
                        present_value = torch.zeros(
                            batch_size,
                            num_heads,
                            new_sequence_length,
                            head_size,
                            device=args.target_device,
                            dtype=args.torch_dtype,
                        )
                        outputs.update(
                            {
                                f"present.{i}.key": present_key.contiguous(),
                                f"present.{i}.value": present_value.contiguous(),
                            }
                        )

            wall_clock_end_time = time.perf_counter()

            # Filter out any anomaly accelerator times (e.g. for `torch.compile`)
            accelerator_times.pop(0)  # Remove prompt processing time
            if args.anomaly_filtering:
                anomaly_threshold_factor = 10
                min_time_s = min(accelerator_times)
                orig_size = len(accelerator_times)
                accelerator_times = list(
                    filter(lambda acc_time: acc_time < anomaly_threshold_factor * min_time_s, accelerator_times)
                )
                new_size = len(accelerator_times)
                logger.info(
                    f"Filtered out {orig_size - new_size} anomaly accelerator times that are {anomaly_threshold_factor}x greater than {min_time_s * 1000} ms..."
                )

            #######################################################
            # Calculate sampling and first token generated metrics
            #######################################################

            # Calculate sampling metrics
            avg_sampling_latency_s = sum(sampling_times) / len(sampling_times)
            avg_sampling_latency_ms = avg_sampling_latency_s * 1000
            avg_sampling_thrpt = batch_size * (1 / avg_sampling_latency_s)
            logger.info(f"Average Latency of Sampling: {avg_sampling_latency_ms} ms")
            logger.info(f"Average Throughput of Sampling: {avg_sampling_thrpt} tps")

            # Calculate first token generated metrics
            first_token_latency_s = accelerator_times[0]
            first_token_latency_ms = first_token_latency_s * 1000
            first_token_thrpt = batch_size * (1 / first_token_latency_s)
            logger.info(f"Latency of First Token Generated: {first_token_latency_ms} ms")
            logger.info(f"Throughput of First Token Generated: {first_token_thrpt} tps")

            ####################################################
            # Calculate first `halfway` token generated metrics
            ####################################################

            halfway = args.generation_length // 2
            halfway_token_latency_s = sum(accelerator_times[:halfway]) / len(accelerator_times[:halfway])
            halfway_token_latency_ms = halfway_token_latency_s * 1000
            halfway_token_thrpt = batch_size * (1 / halfway_token_latency_s)
            logger.info(f"Average Latency of First {halfway} Tokens Generated: {halfway_token_latency_ms} ms")
            logger.info(f"Average Throughput of First {halfway} Tokens Generated: {halfway_token_thrpt} tps")

            #########################################
            # Calculate all tokens generated metrics
            #########################################

            all_token_latency_s = sum(accelerator_times) / len(accelerator_times)
            all_token_latency_ms = all_token_latency_s * 1000
            all_token_thrpt = batch_size * (1 / all_token_latency_s)
            logger.info(
                f"Average Latency of First {args.generation_length} Tokens Generated: {all_token_latency_ms} ms"
            )
            logger.info(f"Average Throughput of First {args.generation_length} Tokens Generated: {all_token_thrpt} tps")

            ###############################
            # Calculate wall clock metrics
            ###############################

            wall_clock_latency_s = wall_clock_end_time - wall_clock_start_time
            wall_clock_thrpt = batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)
            logger.info(f"Wall-Clock Latency: {wall_clock_latency_s} s")
            logger.info(
                f"Wall-Clock Throughput: {batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)} tps"
            )

            # Add metrics to CSV
            logger.info("Adding results to CSV")
            csv_metrics.extend(
                [
                    avg_sampling_latency_ms,
                    avg_sampling_thrpt,
                    first_token_latency_ms,
                    first_token_thrpt,
                    halfway_token_latency_ms,
                    halfway_token_thrpt,
                    all_token_latency_ms,
                    all_token_thrpt,
                    wall_clock_latency_s,
                    wall_clock_thrpt,
                ]
            )
            all_csv_metrics.append(csv_metrics)

        except Exception as e:
            logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length} - {e}")

    filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
    save_results(all_csv_metrics, filename, args.generation_length)


if __name__ == "__main__":
    main()