onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/llama/dist_settings.py
@@ -0,0 +1,57 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import os
+
+import torch.distributed as dist
+
+
+def init_dist():
+    if "LOCAL_RANK" in os.environ:
+        int(os.environ["LOCAL_RANK"])
+        rank = int(os.environ["RANK"])
+        world_size = int(os.environ["WORLD_SIZE"])
+
+        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank)
+    elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
+        int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
+        rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
+        world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
+
+        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank)
+    else:
+        # don't need to do init for single process
+        pass
+
+
+def _get_comm():
+    try:
+        from mpi4py import MPI
+
+        comm = MPI.COMM_WORLD
+        return comm
+    except ImportError:
+        return None
+
+
+def get_rank():
+    comm = _get_comm()
+    return comm.Get_rank() if comm is not None else 0
+
+
+def get_size():
+    comm = _get_comm()
+    return comm.Get_size() if comm is not None else 1
+
+
+def barrier():
+    comm = _get_comm()
+    if comm is not None:
+        comm.Barrier()
+
+
+def print_out(*args):
+    if get_rank() == 0:
+        print(*args)
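This module is a thin wrapper around torch.distributed and mpi4py for multi-process export and benchmarking. Below is a minimal usage sketch, not part of the package: it assumes the wheel's dotted import path for this module and a launcher such as torchrun or mpirun that sets the rank environment variables read by init_dist().

# Hypothetical driver script; assumes torch (and optionally mpi4py) are
# installed and that it is launched with torchrun or mpirun so that
# LOCAL_RANK/RANK/WORLD_SIZE (or the OMPI_* equivalents) are set.
# Single-process runs also work: init_dist() is then a no-op and
# get_rank()/get_size() fall back to 0/1.
from onnxruntime.transformers.models.llama.dist_settings import (
    barrier,
    get_rank,
    get_size,
    init_dist,
    print_out,
)

init_dist()
print_out(f"running on {get_size()} process(es)")  # printed by rank 0 only
# ... per-rank work goes here, e.g. exporting one shard of a model ...
barrier()  # wait for all MPI ranks (no-op without mpi4py)
print_out(f"rank {get_rank()} finished")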
onnxruntime/transformers/models/llama/llama_inputs.py
@@ -0,0 +1,503 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+from __future__ import annotations
+
+import numpy as np
+import torch
+from transformers import AutoConfig, AutoTokenizer
+
+from onnxruntime import InferenceSession, OrtValue
+
+
+# Get position_ids from attention_mask
+def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
+    position_ids = attention_mask.long().cumsum(-1) - 1
+    position_ids.masked_fill_(attention_mask == 0, 1)
+    if use_past_kv:
+        # Shape: (batch_size, 1)
+        position_ids = position_ids[:, -1].unsqueeze(-1)
+
+    # Shape: (batch_size, sequence_length)
+    return position_ids
+
+
+# Inputs for first pass to get initial past_key_values
+# input_ids: (batch_size, sequence_length)
+# attention_mask: (batch_size, sequence_length)
+# position_ids: (batch_size, sequence_length)
+def get_sample_inputs(
+    config: AutoConfig,
+    device: torch.device,
+    batch_size: int,
+    seq_len: int,
+    engine: str = "pt",
+    return_dict: bool = False,
+):
+    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
+    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
+    position_ids = get_position_ids(attention_mask, use_past_kv=False)
+
+    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
+    input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
+    attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
+    position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+
+    if not return_dict:
+        # For export
+        return (input_ids, attention_mask, position_ids)
+
+    inputs = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "position_ids": position_ids,
+    }
+    return inputs
+
+
+# Inputs for subsequent passes with past_key_values
+# input_ids: (batch_size, 1)
+# attention_mask: (batch_size, past_sequence_length + 1)
+# position_ids: (batch_size, 1)
+# past_key: (batch_size, num_heads, past_sequence_length, head_size)
+# past_value: (batch_size, num_heads, past_sequence_length, head_size)
+def get_sample_with_past_kv_inputs(
+    config: AutoConfig,
+    device: torch.device,
+    batch_size: int,
+    past_seq_len: int,
+    use_fp16: bool = False,
+    engine: str = "pt",
+    return_dict: bool = False,
+    world_size: int = 1,
+):
+    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
+    attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
+    # position_ids is of shape (batch_size, 1)
+    position_ids = get_position_ids(attention_mask, use_past_kv=True)
+    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
+
+    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
+    input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
+    attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
+    position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+    past_kv = (
+        flatten_past_kv_inputs(past_kv)
+        if engine == "ort"
+        else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
+    )
+
+    if not return_dict:
+        # For export
+        assert isinstance(past_kv, list)
+        return (input_ids, attention_mask, position_ids, past_kv)
+
+    inputs = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "position_ids": position_ids,
+    }
+    if engine == "ort":
+        assert isinstance(past_kv, dict)
+        inputs.update(past_kv)
+    else:
+        assert isinstance(past_kv, list)
+        inputs["past_key_values"] = past_kv
+
+    return inputs
+
+
+# Inputs for all passes with past_key_values
+# input_ids: (batch_size, sequence_length)
+# attention_mask: (batch_size, past_sequence_length + sequence_length)
+# position_ids: (batch_size, sequence_length)
+# past_key: (batch_size, num_heads, kv_sequence_length, head_size)
+#   For models with GQA, kv_sequence_length = max_sequence_length
+#   For models without GQA, kv_sequence_length = past_sequence_length
+# past_value: (batch_size, num_heads, kv_sequence_length, head_size)
+#   For models with GQA, kv_sequence_length = max_sequence_length
+#   For models without GQA, kv_sequence_length = past_sequence_length
+def get_merged_sample_with_past_kv_inputs(
+    config: AutoConfig,
+    device: torch.device,
+    batch_size: int,
+    seq_len: int,
+    past_seq_len: int,
+    max_seq_len: int,
+    use_fp16: bool = False,
+    use_buffer_share: bool = False,
+    engine: str = "pt",
+    return_dict: bool = False,
+    world_size: int = 1,
+):
+    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
+    attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
+    # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
+    position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
+    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
+
+    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
+    input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
+    attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
+    position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+    past_kv = (
+        flatten_past_kv_inputs(past_kv)
+        if engine == "ort"
+        else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
+    )
+
+    if not return_dict:
+        # For export
+        assert isinstance(past_kv, list)
+        return (input_ids, attention_mask, position_ids, past_kv)
+
+    inputs = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "position_ids": position_ids,
+    }
+    if engine == "ort":
+        assert isinstance(past_kv, dict)
+        inputs.update(past_kv)
+
+        if use_buffer_share:
+            inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)
+
+    else:
+        assert isinstance(past_kv, list)
+        inputs["past_key_values"] = past_kv
+
+    return inputs
+
+
+# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
+def get_msft_sample_inputs(
+    config: AutoConfig,
+    batch_size: int,
+    past_seq_len: int,
+    seq_len: int,
+    max_seq_len: int,
+    use_fp16: bool,
+    use_buffer_share: bool,
+    split_kv: bool,
+):
+    np_dtype = np.float16 if use_fp16 else np.float32
+    head_size = config.hidden_size // config.num_attention_heads
+
+    if not split_kv:
+        ort_inputs = {
+            "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
+            "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
+            "k_cache": np.random.rand(
+                batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
+            ).astype(np_dtype),
+            "v_cache": np.random.rand(
+                batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
+            ).astype(np_dtype),
+            "pos": np.array(past_seq_len, dtype=np.int64),
+        }
+    else:
+        ort_inputs = {
+            "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
+            "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
+                np.int32
+            ),
+            "pos": np.array(past_seq_len, dtype=np.int64),
+        }
+        for i in range(config.num_hidden_layers):
+            ort_inputs.update(
+                {
+                    f"k_{i}_cache": np.random.rand(
+                        batch_size, config.num_attention_heads, past_seq_len, head_size
+                    ).astype(np_dtype),
+                    f"v_{i}_cache": np.random.rand(
+                        batch_size, config.num_attention_heads, past_seq_len, head_size
+                    ).astype(np_dtype),
+                }
+            )
+
+    if use_buffer_share:
+        ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
+
+    return ort_inputs
+
+
+# Create past_key_values
+# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
+def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
+    num_heads = config.num_key_value_heads // world_size
+    head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+    torch_dtype = torch.float16 if use_fp16 else torch.float32
+    past_kv = [
+        (
+            torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
+            torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
+        )
+        for _ in range(config.num_hidden_layers)
+    ]
+    return past_kv
+
+
+# Convert list of past_key_values to dict of past_key and past_value
+def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
+    past_kv = {}
+    for i, (past_k, past_v) in enumerate(past_key_values):
+        past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
+        past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
+    return past_kv
+
+
+# Format PyTorch inputs to ONNX Runtime inputs
+def convert_inputs_for_ort(
+    pt_inputs: dict,
+    use_buffer_share: bool = False,
+    past_seq_len: int = 0,
+    max_seq_len: int = 2048,
+):
+    ort_inputs = {}
+    for k, v in pt_inputs.items():
+        if isinstance(v, np.ndarray):
+            ort_inputs[k] = v
+        elif k == "past_key_values":
+            ort_inputs.update(flatten_past_kv_inputs(v))
+        else:
+            ort_inputs[k] = v.detach().cpu().numpy()
+
+    # Reshape KV caches if using past-present-share-buffer
+    if use_buffer_share:
+        ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
+
+    return ort_inputs
+
+
+# Re-allocate KV caches from (batch_size, num_heads, past_sequence_length, head_size) to
+# (batch_size, num_heads, max_sequence_length, head_size) for past-present buffer sharing
+def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
+    for k, v in ort_inputs.items():
+        # Allocate new buffers with max_sequence_length for GQA
+        if "cache" in k or "past_key_values" in k:
+            # Copy v (BxSxPxH) into new_v (BxSxMxH)
+            batch_size, num_heads, _, head_size = v.shape
+            new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
+            new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
+            ort_inputs[k] = new_v
+    return ort_inputs
+
+
+# Verify ONNX Runtime inputs with model
+def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
+    # Check that all model inputs will be provided
+    model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
+    user_inputs = set(ort_inputs.keys())
+    missing_inputs = model_inputs - user_inputs
+    if len(missing_inputs):
+        print(f"The following model inputs are missing: {missing_inputs}")
+        raise Exception("There are missing inputs to the model. Please add them and try again.")
+
+    # Remove unnecessary inputs from model inputs
+    unnecessary_inputs = user_inputs - model_inputs
+    if len(unnecessary_inputs):
+        for unnecessary_input in unnecessary_inputs:
+            del ort_inputs[unnecessary_input]
+
+    return ort_inputs
+
+
+# Add IO bindings for execution providers using OrtValue
+# Use when you need to run inference once or twice to save memory
+def add_io_bindings_as_ortvalues(
+    model: InferenceSession,
+    ort_inputs: dict,
+    device: str,
+    device_id: int,
+    use_buffer_share: bool,
+    kv_cache_ortvalues: dict,
+):
+    io_binding = model.io_binding()
+
+    model_inputs = set(map(lambda i: i.name, model.get_inputs()))
+    for k, v in ort_inputs.items():
+        # Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
+        # GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
+        # but `position_ids` is used as a PyTorch model input
+        if k not in model_inputs:
+            continue
+
+        # Bind OrtValue inputs to device
+        if use_buffer_share and ("cache" in k or "past_key_values" in k):
+            if k not in kv_cache_ortvalues:
+                v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
+                io_binding.bind_ortvalue_input(k, v_device)
+                kv_cache_ortvalues[k] = v_device
+            else:
+                kv_cache_ortvalues[k].update_inplace(v)
+                io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k])
+        else:
+            v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
+            io_binding.bind_ortvalue_input(k, v_device)
+
+    for output in model.get_outputs():
+        name = output.name
+        if use_buffer_share and ("out" in name or "present" in name):
+            # Bind present KV cache outputs to past KV cache inputs in order to buffer share
+            input_name = name.replace("out", "cache").replace("present", "past_key_values")
+            io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
+        else:
+            io_binding.bind_output(name, device_type=device, device_id=device_id)
+
+    return io_binding, kv_cache_ortvalues
+
+
+# Add IO bindings for execution providers using PyTorch tensors
+# Use when you need to run inference many times
+def add_io_bindings_as_tensors(
+    model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
+):
+    # Verify model inputs
+    inputs = verify_ort_inputs(model, inputs)
+
+    device = None
+    pt_to_np = {
+        "torch.int32": np.int32,
+        "torch.int64": np.int64,
+        "torch.float16": np.float16,
+        "torch.float32": np.float32,
+    }
+
+    # Bind inputs/outputs to IO binding
+    io_binding = model.io_binding()
+    for k, v in inputs.items():
+        io_binding.bind_input(
+            name=k,
+            device_type=v.device.type,
+            device_id=0 if v.device.type == "cpu" else v.device.index,
+            element_type=pt_to_np[repr(v.dtype)],
+            shape=tuple(v.shape),
+            buffer_ptr=v.data_ptr(),
+        )
+        device = v.device
+
+    for output in model.get_outputs():
+        name = output.name
+        # Bind KV cache outputs to KV cache inputs
+        v = (
+            inputs[name.replace("present", "past_key_values")]
+            if use_buffer_share and "present" in name
+            else outputs[name]
+        )
+        io_binding.bind_output(
+            name=name,
+            device_type=device.type,
+            device_id=0 if device.type == "cpu" else device.index,
+            element_type=(np.float16 if use_fp16 else np.float32),
+            shape=tuple(v.shape),
+            buffer_ptr=v.data_ptr(),
+        )
+
+    return io_binding
+
+
+# Get actual inputs when using real data (instead of sample data) and initialize outputs
+def get_initial_inputs_and_outputs(
+    config: AutoConfig,
+    tokenizer: AutoTokenizer,
+    requested_length: int,
+    prompt: list[str],
+    device: torch.device,
+    use_fp16: bool,
+    use_buffer_share: bool,
+    engine: str,
+):
+    tokenizer.pad_token = tokenizer.eos_token
+    encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
+    torch_dtype = torch.float16 if use_fp16 else torch.float32
+
+    # input_ids: pad token id is 0
+    # attention_mask: pad token id is 0
+    # position_ids: pad token id is 1
+    input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
+    attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
+    position_ids = get_position_ids(attention_mask, use_past_kv=False)
+
+    # Check if tokenized prompt length matches the requested prompt length
+    tokenized_length = input_ids.shape[-1]
+    if tokenized_length > requested_length:
+        # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+        input_ids = input_ids[:, :requested_length]
+        attention_mask = attention_mask[:, :requested_length]
+        position_ids = get_position_ids(attention_mask, use_past_kv=False)
+    elif tokenized_length < requested_length:
+        # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+        input_ids_first_col = input_ids[:, 0].unsqueeze(0).T
+        attention_mask_first_col = attention_mask[:, 0].unsqueeze(0).T
+        for _ in range(requested_length - tokenized_length):
+            input_ids = torch.hstack((input_ids_first_col, input_ids))
+            attention_mask = torch.hstack((attention_mask_first_col, attention_mask))
+        position_ids = get_position_ids(attention_mask, use_past_kv=False)
+
+    tokenized_length = input_ids.shape[-1]
+    assert tokenized_length == requested_length
+
+    # Create inputs
+    inputs = {
+        "input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
+        "attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
+        "position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
+    }
+    if engine != "ort":
+        inputs["past_key_values"] = []
+
+    # Get shape of KV cache inputs
+    batch_size, sequence_length = input_ids.shape
+    max_sequence_length = config.max_position_embeddings
+    num_heads = config.num_key_value_heads
+    head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+
+    # Create KV cache inputs
+    for i in range(config.num_hidden_layers):
+        past_key = torch.zeros(
+            batch_size,
+            num_heads,
+            max_sequence_length if use_buffer_share else 0,
+            head_size,
+            device=device,
+            dtype=torch_dtype,
+        )
+        past_value = torch.zeros(
+            batch_size,
+            num_heads,
+            max_sequence_length if use_buffer_share else 0,
+            head_size,
+            device=device,
+            dtype=torch_dtype,
+        )
+        if engine == "ort":
+            inputs.update(
+                {
+                    f"past_key_values.{i}.key": past_key.contiguous(),
+                    f"past_key_values.{i}.value": past_value.contiguous(),
+                }
+            )
+        else:
+            inputs["past_key_values"].append((past_key, past_value))
+
+    outputs = None
+    if engine == "ort":
+        # Create outputs
+        logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
+        outputs = {"logits": logits.contiguous()}
+        if not use_buffer_share:
+            for i in range(config.num_hidden_layers):
+                present_key = torch.zeros(
+                    batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
+                )
+                present_value = torch.zeros(
+                    batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
+                )
+                outputs.update(
+                    {f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
+                )
+
+    return inputs, outputs
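These helpers cover the input path for an exported LLaMA decoder: building sample or tokenized inputs, flattening past_key_values into the past_key_values.{i}.key/.value names the graph uses, pre-allocating max-length KV buffers for GQA buffer sharing, and IO-binding tensors to a device. Below is a minimal sketch, not part of the package, of a first (no-past) pass; the model id and "model.onnx" path are placeholders, and the graph is assumed to take the input names these helpers produce.

# Hypothetical usage; "model.onnx" and the Hugging Face model id are placeholders.
# Assumes a decoder exported with inputs named input_ids / attention_mask /
# position_ids / past_key_values.{i}.key / past_key_values.{i}.value.
import torch
from transformers import AutoConfig

from onnxruntime import InferenceSession
from onnxruntime.transformers.models.llama.llama_inputs import (
    get_merged_sample_with_past_kv_inputs,
    verify_ort_inputs,
)

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
session = InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# engine="ort" yields NumPy arrays; past_seq_len=0 makes this the prompt pass
ort_inputs = get_merged_sample_with_past_kv_inputs(
    config,
    torch.device("cpu"),
    batch_size=1,
    seq_len=8,
    past_seq_len=0,
    max_seq_len=2048,
    engine="ort",
    return_dict=True,
)
ort_inputs = verify_ort_inputs(session, ort_inputs)  # prune inputs the graph lacks
logits = session.run(None, ort_inputs)[0]

For repeated decoding steps on GPU, add_io_bindings_as_tensors (with use_buffer_share=True and inputs from get_initial_inputs_and_outputs) avoids re-allocating the KV cache on every step.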