onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,646 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
import csv
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import random
|
|
11
|
+
import sys
|
|
12
|
+
import time
|
|
13
|
+
import timeit
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from enum import Enum
|
|
18
|
+
from time import sleep
|
|
19
|
+
from typing import Any, Dict, List, Optional
|
|
20
|
+
|
|
21
|
+
import coloredlogs
|
|
22
|
+
import numpy
|
|
23
|
+
import torch
|
|
24
|
+
import transformers
|
|
25
|
+
from packaging import version
|
|
26
|
+
|
|
27
|
+
import onnxruntime
|
|
28
|
+
|
|
29
|
+
# Module-level logger used by the benchmark helpers below.
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Precision(Enum):
    """Numeric precision of a benchmarked model.

    ``str()`` of a member yields its short tag (e.g. ``"fp16"``) instead of the
    default ``"Precision.FLOAT16"`` representation.
    """

    FLOAT32 = "fp32"
    FLOAT16 = "fp16"
    INT8 = "int8"
    INT4 = "int4"

    def __str__(self):
        # Render only the bare tag so it can be used directly in names and CSV cells.
        return f"{self.value}"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class OptimizerInfo(Enum):
    """Which optimizer produced the ONNX model being benchmarked.

    ``str()`` of a member yields its short tag (e.g. ``"by_ort"``).
    """

    # no_opt means using the raw ONNX model, but OnnxRuntime might still apply optimization as long as
    # graph optimization level is not 0 (disable all).
    NOOPT = "no_opt"
    BYORT = "by_ort"
    BYSCRIPT = "by_script"

    def __str__(self):
        # Render only the bare tag rather than "OptimizerInfo.<NAME>".
        return f"{self.value}"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ConfigModifier:
    """Override the layer count(s) on a model config object.

    When ``num_layers`` is None, ``modify`` leaves the config untouched.
    """

    # (config attribute name, word used in the log message) pairs that ``modify`` updates.
    _LAYER_COUNT_ATTRIBUTES = (
        ("num_hidden_layers", "hidden"),
        ("encoder_layers", "encoder"),
        ("decoder_layers", "decoder"),
    )

    def __init__(self, num_layers):
        # Target layer count, or None to keep the config's own value.
        self.num_layers = num_layers

    def modify(self, config):
        """Set each known layer-count attribute that ``config`` has to ``self.num_layers``.

        Checked attributes: ``num_hidden_layers``, ``encoder_layers`` and
        ``decoder_layers``. Attributes the config does not have are not created.

        Args:
            config: any object exposing some of the attributes above.
        """
        if self.num_layers is None:
            return
        for attribute, label in self._LAYER_COUNT_ATTRIBUTES:
            # Exact attribute name lookup; a stray character here would make the branch dead.
            if hasattr(config, attribute):
                setattr(config, attribute, self.num_layers)
                logger.info(f"Modifying pytorch model's number of {label} layers to: {self.num_layers}")

    def get_layer_num(self):
        """Return the configured layer count (None means the config is not modified)."""
        return self.num_layers
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Maps str(numpy dtype) of an input array to the numpy type passed to IOBinding.bind_input.
# Dtypes not listed here fall back to the caller-supplied default data type.
IO_BINDING_DATA_TYPE_MAP = dict(
    float32=numpy.float32,
    # TODO: Add more.
)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def create_onnxruntime_session(
    onnx_model_path,
    use_gpu,
    provider=None,
    enable_all_optimization=True,
    num_threads=-1,
    enable_profiling=False,
    verbose=False,
    enable_mlas_gemm_fastmath_arm64_bfloat16=False,
    provider_options={},  # map execution provider name to its option # noqa: B006
):
    """Create an ``onnxruntime.InferenceSession`` configured for benchmarking.

    Args:
        onnx_model_path: path of the ONNX model to load.
        use_gpu: when False only ``CPUExecutionProvider`` is requested; otherwise the
            provider list is chosen from ``provider``.
        provider: "dml", "rocm", "migraphx", "cuda" or "tensorrt"; any other value
            (including None) selects CUDA with CPU fallback.
        enable_all_optimization: ORT_ENABLE_ALL when True, ORT_ENABLE_BASIC otherwise.
        num_threads: intra-op thread count; values <= 0 keep the onnxruntime default.
        enable_profiling: turn on onnxruntime profiling for the session.
        verbose: log severity level 0 when True, 4 otherwise.
        enable_mlas_gemm_fastmath_arm64_bfloat16: set the
            "mlas.enable_gemm_fastmath_arm64_bfloat16" session config entry to "1".
        provider_options: maps an execution provider name to its options; only read.

    Returns:
        The created session, or None if any step raised (the exception is logged).
    """
    # Ordered provider preference per GPU backend; CPU is always the final fallback.
    gpu_providers_by_backend = {
        "dml": ["DmlExecutionProvider", "CPUExecutionProvider"],
        "rocm": ["ROCMExecutionProvider", "CPUExecutionProvider"],
        "migraphx": ["MIGraphXExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"],
        "cuda": ["CUDAExecutionProvider", "CPUExecutionProvider"],
        "tensorrt": ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
    }

    try:
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
            if enable_all_optimization
            else onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
        )

        if enable_profiling:
            sess_options.enable_profiling = True

        if num_threads > 0:
            sess_options.intra_op_num_threads = num_threads
            logger.debug(f"Session option: intra_op_num_threads={sess_options.intra_op_num_threads}")

        sess_options.log_severity_level = 0 if verbose else 4

        logger.debug(f"Create session for onnx model: {onnx_model_path}")
        if use_gpu:
            # Unknown or missing backend names default to CUDA.
            providers = gpu_providers_by_backend.get(provider, gpu_providers_by_backend["cuda"])
        else:
            providers = ["CPUExecutionProvider"]

        if provider_options:
            # Pair each provider with its options when some were supplied for it.
            providers = [(ep, provider_options[ep]) if ep in provider_options else ep for ep in providers]

        if enable_mlas_gemm_fastmath_arm64_bfloat16:
            sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1")

        return onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers)
    except Exception:
        # Best-effort: report the failure and let the caller handle a None session.
        logger.error("Exception", exc_info=True)  # noqa: G201
        return None
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def setup_logger(verbose=True):
    """Install coloredlogs handlers for benchmark output.

    Args:
        verbose: True installs DEBUG-level logging with a file/line/function prefix;
            False installs message-only formatting and raises the "transformers"
            logger's level to WARNING.
    """
    if not verbose:
        coloredlogs.install(fmt="%(message)s")
        logging.getLogger("transformers").setLevel(logging.WARNING)
        return

    coloredlogs.install(level="DEBUG", fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
    """Create working directories and validate the runtime environment.

    The checks are explicit raises rather than ``assert`` statements so that they
    still run when Python is started with ``-O``.

    Args:
        cache_dir: directory to create when given and missing.
        output_dir: directory to create when given and missing.
        use_gpu: when True, require a GPU execution provider in this onnxruntime build.
        provider: "dml" requires DmlExecutionProvider; any other value requires one
            of the CUDA, ROCm or MIGraphX execution providers.

    Raises:
        RuntimeError: the required execution provider is unavailable, or PyTorch,
            Transformers or OnnxRuntime is older than the minimum supported version.
    """
    for directory in (cache_dir, output_dir):
        if directory and not os.path.exists(directory):
            # exist_ok guards against the directory being created between the check and here.
            os.makedirs(directory, exist_ok=True)

    if use_gpu:
        available_providers = set(onnxruntime.get_available_providers())
        if provider == "dml":
            if "DmlExecutionProvider" not in available_providers:
                raise RuntimeError("Please install onnxruntime-directml package to test GPU inference.")
        elif available_providers.isdisjoint(
            ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"]
        ):
            raise RuntimeError(
                "Please install onnxruntime-gpu package, or install ROCm support, to test GPU inference."
            )

    logger.info(f"PyTorch Version:{torch.__version__}")
    logger.info(f"Transformers Version:{transformers.__version__}")
    logger.info(f"OnnxRuntime Version:{onnxruntime.__version__}")

    # Support three major versions of PyTorch and OnnxRuntime, and up to 9 months of transformers.
    minimum_versions = (
        ("PyTorch", torch.__version__, "1.10.0"),
        ("Transformers", transformers.__version__, "4.12.0"),
        ("OnnxRuntime", onnxruntime.__version__, "1.10.0"),
    )
    for package_name, installed, minimum in minimum_versions:
        if version.parse(installed) < version.parse(minimum):
            raise RuntimeError(
                f"{package_name} {installed} is older than the minimum supported version {minimum}."
            )
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def get_latency_result(latency_list, batch_size):
    """Summarize per-run latencies into formatted benchmark statistics.

    Args:
        latency_list: non-empty sequence of per-run latencies in seconds.
        batch_size: samples processed per run; used to compute throughput.

    Returns:
        Dict with "test_times" (run count) and the strings "latency_variance",
        "latency_90/95/99_percentile" (ms), "average_latency_ms" and "QPS", each
        formatted with two decimals.
    """
    run_count = len(latency_list)
    mean_ms = sum(latency_list) / float(run_count) * 1000.0
    # NOTE(review): this is the variance of latencies in seconds scaled by 1000, which
    # is neither ms^2 (x1e6) nor a standard deviation in ms; kept as-is for comparability.
    variance = numpy.var(latency_list, dtype=numpy.float64) * 1000.0
    queries_per_second = batch_size * (1000.0 / mean_ms)

    summary = {"test_times": run_count, "latency_variance": f"{variance:.2f}"}
    for key, percent in (
        ("latency_90_percentile", 90),
        ("latency_95_percentile", 95),
        ("latency_99_percentile", 99),
    ):
        summary[key] = f"{numpy.percentile(latency_list, percent) * 1000.0:.2f}"
    summary["average_latency_ms"] = f"{mean_ms:.2f}"
    summary["QPS"] = f"{queries_per_second:.2f}"
    return summary
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def output_details(results, csv_filename):
    """Append one CSV row per benchmark result to ``csv_filename``.

    A header line is written on every call, even if the file already has content.
    Result dicts may only use keys from the column list below (csv.DictWriter
    rejects unknown keys by default); absent keys become empty cells.

    Args:
        results: iterable of per-run result dicts.
        csv_filename: path of the CSV file to append to (ASCII encoded).
    """
    fieldnames = [
        "engine",
        "version",
        "providers",
        "device",
        "precision",
        "optimizer",
        "io_binding",
        "model_name",
        "inputs",
        "threads",
        "batch_size",
        "sequence_length",
        "custom_layer_num",
        "datetime",
        "test_times",
        "QPS",
        "average_latency_ms",
        "latency_variance",
        "latency_90_percentile",
        "latency_95_percentile",
        "latency_99_percentile",
    ]
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    logger.info(f"Detail results are saved to csv file: {csv_filename}")
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def output_summary(results, csv_filename, args):
    """Append a pivoted average-latency summary to ``csv_filename``.

    One row is written per (model, input count, engine, io_binding, threads)
    combination that matches at least one entry of ``results``. A row holds the
    shared header fields plus the ``average_latency_ms`` of every matching result,
    placed in a ``b{batch_size}`` or ``b{batch_size}_s{sequence_length}`` column.
    A header line is written on every call.

    Args:
        results: per-run result dicts; each matching one is read for the header_names
            keys plus "batch_size", "sequence_length" and "average_latency_ms".
        csv_filename: path of the CSV file to append to (ASCII encoded).
        args: parsed options; reads ``batch_sizes``, ``sequence_lengths``, ``models``,
            ``engines`` and ``num_threads``.
    """
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        # Fields that must be identical across all results merged into one row.
        header_names = [
            "model_name",
            "inputs",
            "custom_layer_num",
            "engine",
            "version",
            "providers",
            "device",
            "precision",
            "optimizer",
            "io_binding",
            "threads",
        ]
        # One latency column per batch size, or per (batch size, sequence length)
        # unless sequence_lengths is the sentinel [""].
        data_names = []
        for batch_size in args.batch_sizes:
            if args.sequence_lengths == [""]:
                data_names.append(f"b{batch_size}")
            else:
                for sequence_length in args.sequence_lengths:
                    data_names.append(f"b{batch_size}_s{sequence_length}")

        csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
        csv_writer.writeheader()
        # The nesting order of these loops determines the order of rows in the CSV.
        # Only input counts 1-3 are summarized.
        for model_name in args.models:
            for input_count in [1, 2, 3]:
                for engine_name in args.engines:
                    # NOTE(review): "" presumably matches results from engines without an
                    # io_binding setting — confirm against the result producers.
                    for io_binding in [True, False, ""]:
                        for threads in args.num_threads:
                            row = {}
                            for result in results:
                                if (
                                    result["model_name"] == model_name
                                    and result["inputs"] == input_count
                                    and result["engine"] == engine_name
                                    and result["io_binding"] == io_binding
                                    and result["threads"] == threads
                                ):
                                    headers = {k: v for k, v in result.items() if k in header_names}
                                    if not row:
                                        # First match seeds the header fields and blank latency cells.
                                        row.update(headers)
                                        row.update({k: "" for k in data_names})
                                    else:
                                        # Later matches must agree on every header field.
                                        for k in header_names:
                                            assert row[k] == headers[k]
                                    b = result["batch_size"]
                                    s = result["sequence_length"]
                                    # A falsy sequence length means the batch-only column.
                                    if s:
                                        row[f"b{b}_s{s}"] = result["average_latency_ms"]
                                    else:
                                        row[f"b{b}"] = result["average_latency_ms"]
                            if row:
                                csv_writer.writerow(row)

    logger.info(f"Summary results are saved to csv file: {csv_filename}")
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def output_fusion_statistics(model_fusion_statistics, csv_filename):
    """Append one CSV row of fusion statistics per model file to ``csv_filename``.

    Columns are "model_filename", "datetime", "transformers", "torch", followed by
    the keys of the first statistics dict. Each statistics dict is mutated in place
    by adding those four fields before it is written. A header line is written on
    every call.

    Args:
        model_fusion_statistics: maps model filename to its fusion statistics dict;
            must be non-empty, otherwise ``next`` raises StopIteration (after the
            file has already been opened).
        csv_filename: path of the CSV file to append to (ASCII encoded).
    """
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        first_stats = next(iter(model_fusion_statistics.values()))
        fieldnames = ["model_filename", "datetime", "transformers", "torch", *first_stats.keys()]
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        for model_filename, stats in model_fusion_statistics.items():
            stats.update(
                {
                    "datetime": str(datetime.now()),
                    "transformers": transformers.__version__,
                    "torch": torch.__version__,
                    "model_filename": model_filename,
                }
            )
            writer.writerow(stats)
    logger.info(f"Fusion statistics is saved to csv file: {csv_filename}")
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, warm_up_repeat=0):
    """Benchmark ``ort_session.run`` with regular (non-IO-binding) feeds.

    Runs ``warm_up_repeat`` untimed warm-up iterations, then times ``repeat_times`` single runs.
    Returns a copy of ``result_template`` extended with ``"io_binding": False`` and the latency
    statistics produced by ``get_latency_result``.
    """

    def run_once():
        return ort_session.run(None, ort_inputs)

    timeit.repeat(run_once, number=1, repeat=warm_up_repeat)  # warm-up; timings are discarded
    latencies = timeit.repeat(run_once, number=1, repeat=repeat_times)

    summary = dict(result_template)
    summary["io_binding"] = False
    summary.update(get_latency_result(latencies, batch_size))
    return summary
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def inference_ort_with_io_binding(
    ort_session,
    ort_inputs,
    result_template,
    repeat_times,
    ort_output_names,
    ort_outputs,
    output_buffers,
    output_buffer_max_sizes,
    batch_size,
    device,
    data_type=numpy.longlong,
    warm_up_repeat=0,
):
    """Benchmark ``ort_session`` using IO binding, with inputs and outputs placed on ``device``.

    Args:
        ort_session: the ONNX Runtime inference session to time.
        ort_inputs: mapping of input name -> numpy array; each array is copied to ``device``.
        result_template: dict copied into the returned result.
        repeat_times: number of timed runs.
        ort_output_names: names of the outputs to bind, in the same order as ``ort_outputs``.
        ort_outputs: arrays whose ``.shape`` is used for each output binding (presumably outputs of
            an earlier plain run — confirm with callers).
        output_buffers: list of preallocated float32 torch tensors reused across calls. When empty,
            it is filled in place by ``allocateOutputBuffers``.
        output_buffer_max_sizes: sizes passed to ``allocateOutputBuffers`` when allocating.
        batch_size: forwarded to ``get_latency_result``.
        device: torch device that inputs are copied to and output buffers are allocated on.
        data_type: fallback element type for input dtypes not in ``IO_BINDING_DATA_TYPE_MAP``.
        warm_up_repeat: number of untimed warm-up runs.

    Returns:
        ``result_template`` extended with ``"io_binding": True`` and the latency statistics.
    """

    def device_id_of(tensor):
        # CPU tensors have no device index; CUDA tensors report the GPU they actually live on.
        return tensor.device.index if tensor.device.index is not None else 0

    io_binding = ort_session.io_binding()

    # IO binding only records raw data pointers, so every tensor that owns bound input memory must
    # stay referenced until the timed runs finish. Previously each device copy was dropped when the
    # loop variable was rebound. On CUDA the caching allocator could then hand that block to a
    # later input or output buffer while ORT still pointed at it.
    device_inputs = []
    for name in ort_inputs:
        host_array = ort_inputs[name]
        device_tensor = torch.from_numpy(host_array).to(device)
        device_inputs.append(device_tensor)
        io_binding.bind_input(
            name,
            device_tensor.device.type,
            device_id_of(device_tensor),
            IO_BINDING_DATA_TYPE_MAP.get(str(host_array.dtype), data_type),
            device_tensor.shape,
            device_tensor.data_ptr(),
        )

    # Bind output buffers, allocating them with the largest sizes needed if not allocated already.
    if len(output_buffers) == 0:
        allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)

    for i, ort_output_name in enumerate(ort_output_names):
        io_binding.bind_output(
            ort_output_name,
            output_buffers[i].device.type,
            device_id_of(output_buffers[i]),
            numpy.float32,
            ort_outputs[i].shape,
            output_buffers[i].data_ptr(),
        )

    def run_once():
        return ort_session.run_with_iobinding(io_binding)

    timeit.repeat(run_once, number=1, repeat=warm_up_repeat)  # Dry run
    latency_list = timeit.repeat(run_once, number=1, repeat=repeat_times)

    result = {}
    result.update(result_template)
    result.update({"io_binding": True})
    result.update(get_latency_result(latency_list, batch_size))
    return result
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device):  # noqa: N802
    """Append one uninitialized float32 tensor per entry of ``output_buffer_max_sizes`` to ``output_buffers``.

    Buffers are sized for the largest test case so the same memory can be reused for every
    test run. Each entry is passed to ``torch.empty`` as its size; tensors are created on ``device``.
    """
    output_buffers.extend(
        torch.empty(max_size, dtype=torch.float32, device=device) for max_size in output_buffer_max_sizes
    )
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def set_random_seed(seed=123):
    """Seed Python's, NumPy's and torch's (CPU and every CUDA device) RNGs for deterministic results.

    NOTE: cuDNN settings (torch.backends.cudnn.enabled / benchmark / deterministic) are
    intentionally left untouched here.
    """
    seeders = (
        random.seed,
        numpy.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def get_gpu_info() -> Optional[List[Dict[str, Any]]]:
    """Return memory information for every GPU visible to NVML.

    Returns:
        A list with one dict per device, with keys "id", "name", "total", "free" and "used". The
        memory values are taken as-is from ``nvmlDeviceGetMemoryInfo``. Returns None when NVML
        raises ``NVMLError`` or returns an unexpected value.
    """
    from py3nvml.py3nvml import (
        NVMLError,
        nvmlDeviceGetCount,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlDeviceGetName,
        nvmlInit,
        nvmlShutdown,
    )

    try:
        nvmlInit()
    except NVMLError as error:
        logger.error("Error fetching GPU information using nvml: %s", error)
        return None

    # Every successful nvmlInit must be paired with nvmlShutdown, including the early returns below.
    try:
        device_count = nvmlDeviceGetCount()
        if not isinstance(device_count, int):
            return None

        result = []
        for i in range(device_count):
            handle = nvmlDeviceGetHandleByIndex(i)
            info = nvmlDeviceGetMemoryInfo(handle)
            if isinstance(info, str):
                return None
            result.append(
                {
                    "id": i,
                    "name": nvmlDeviceGetName(handle),
                    "total": info.total,
                    "free": info.free,
                    "used": info.used,
                }
            )
        return result
    except NVMLError as error:
        logger.error("Error fetching GPU information using nvml: %s", error)
        return None
    finally:
        try:
            nvmlShutdown()
        except NVMLError as error:
            logger.error("Error shutting down nvml: %s", error)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
class MemoryMonitor(ABC):
    """Base class for peak-memory samplers.

    ``keep_measuring`` controls the polling loops of the ``measure_*`` methods. While it is True
    they keep sampling every ~5 ms; a caller (e.g. ``measure_memory``) sets it to False from
    another thread to make them return. With ``keep_measuring=False`` a call takes a single
    sample and returns.
    """

    def __init__(self, keep_measuring=True):
        self.keep_measuring = keep_measuring

    def measure_cpu_usage(self):
        """Return the peak resident set size of the current process, in MB (1024**2 bytes)."""
        import psutil

        # The process handle does not change between samples; create it once, not every 5 ms.
        process = psutil.Process(os.getpid())
        max_usage = 0
        while True:
            max_usage = max(max_usage, process.memory_info().rss / 1024**2)
            sleep(0.005)  # 5ms
            if not self.keep_measuring:
                break
        return max_usage

    @abstractmethod
    def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
        """Return per-device peak GPU memory usage, or None when it cannot be measured."""
        raise NotImplementedError()
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
class CudaMemoryMonitor(MemoryMonitor):
    """Memory monitor that samples per-device used memory of NVIDIA GPUs through NVML (py3nvml)."""

    def __init__(self, keep_measuring=True):
        super().__init__(keep_measuring)

    def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
        """Poll used memory of every NVML-visible GPU and return the per-device peak.

        Samples every ~5 ms until ``self.keep_measuring`` is False; at least one sample is
        always taken.

        Returns:
            A list of ``{"device_id", "name", "max_used_MB"}`` dicts (used memory divided by
            1024**2), or None when NVML raises ``NVMLError`` or returns an unexpected value.
        """
        from py3nvml.py3nvml import (
            NVMLError,
            nvmlDeviceGetCount,
            nvmlDeviceGetHandleByIndex,
            nvmlDeviceGetMemoryInfo,
            nvmlDeviceGetName,
            nvmlInit,
            nvmlShutdown,
        )

        try:
            nvmlInit()
        except NVMLError as error:
            logger.error("Error fetching GPU information using nvml: %s", error)
            return None

        # Pair the successful nvmlInit with nvmlShutdown on every exit path, including early returns.
        try:
            device_count = nvmlDeviceGetCount()
            if not isinstance(device_count, int):
                logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
                return None

            # Device handles stay valid while NVML is initialized, so look them up once instead of
            # on every 5 ms poll.
            handles = [nvmlDeviceGetHandleByIndex(i) for i in range(device_count)]
            gpu_name = [nvmlDeviceGetName(handle) for handle in handles]
            max_gpu_usage = [0] * device_count
            while True:
                for i, handle in enumerate(handles):
                    info = nvmlDeviceGetMemoryInfo(handle)
                    if isinstance(info, str):
                        logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
                        return None
                    max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
                sleep(0.005)  # 5ms
                if not self.keep_measuring:
                    break
            return [
                {
                    "device_id": i,
                    "name": gpu_name[i],
                    "max_used_MB": max_gpu_usage[i],
                }
                for i in range(device_count)
            ]
        except NVMLError as error:
            logger.error("Error fetching GPU information using nvml: %s", error)
            return None
        finally:
            try:
                nvmlShutdown()
            except NVMLError as error:
                logger.error("Error shutting down nvml: %s", error)
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
class RocmMemoryMonitor(MemoryMonitor):
    """Memory monitor for AMD GPUs based on the ``rocm_smi`` Python module.

    When ``rocm_smi`` cannot be imported, ``self.rocm_smi`` is None. In that case
    ``measure_gpu_usage`` returns None and ``get_used_memory`` returns -1.
    """

    def __init__(self, keep_measuring=True):
        super().__init__(keep_measuring)
        # Make the rocm_smi module under this directory importable when the directory exists.
        rocm_smi_path = "/opt/rocm/libexec/rocm_smi"
        if os.path.exists(rocm_smi_path) and rocm_smi_path not in sys.path:
            sys.path.append(rocm_smi_path)
        # Runs whether or not the path exists, so self.rocm_smi is always set.
        try:
            import rocm_smi

            self.rocm_smi = rocm_smi
            self.rocm_smi.initializeRsmi()
        except ImportError:
            self.rocm_smi = None

    def get_used_memory(self, dev):
        """Return VRAM in use on device ``dev`` (first value of ``getMemInfo`` / 1024**2), or -1 without rocm_smi."""
        if self.rocm_smi is None:
            return -1
        return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024

    def measure_gpu_usage(self):
        """Poll VRAM usage of every rocm_smi device until ``keep_measuring`` is False; return per-device peaks.

        Returns a list of ``{"device_id", "name", "max_used_MB"}`` dicts, or None without rocm_smi.
        """
        if self.rocm_smi is None:
            return None

        device_count = len(self.rocm_smi.listDevices())
        max_gpu_usage = [0] * device_count
        gpu_name = [f"GPU{i}" for i in range(device_count)]
        while True:
            for i in range(device_count):
                max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i))
            # Use the module-level `sleep`, consistent with the other monitors in this file.
            sleep(0.005)  # 5ms
            if not self.keep_measuring:
                break
        return [
            {
                "device_id": i,
                "name": gpu_name[i],
                "max_used_MB": max_gpu_usage[i],
            }
            for i in range(device_count)
        ]
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None):
    """Measure how much memory ``func`` uses on top of a baseline.

    Args:
        is_gpu: when True, measure GPU memory per device with the selected monitor. Otherwise
            measure the resident set size of this process.
        func: zero-argument callable to profile. When None, only the baseline is returned.
        monitor_type: "rocm" selects RocmMemoryMonitor; any other value selects CudaMemoryMonitor.
        start_memory: optional precomputed baseline, used instead of sampling one. It is a
            per-device list for GPU and MB for CPU.

    Returns:
        GPU: the largest per-device growth of "max_used_MB" from the baseline to the peak seen
        while ``func`` ran (never below 0). Returns None when sampling fails or the device lists
        do not line up, and the baseline itself when ``func`` is None.
        CPU: peak RSS while ``func`` ran minus the baseline, in MB; the baseline when ``func`` is None.
    """
    monitor_cls = RocmMemoryMonitor if monitor_type == "rocm" else CudaMemoryMonitor

    # One-shot monitor (keep_measuring=False): each measure_* call takes a single sample.
    baseline_monitor = monitor_cls(False)

    def track_peak_while_running(measure):
        # Sample with `measure` on one worker while `func` runs on another. Polling stops once
        # `func` returns or raises, and the sampler's return value is the observed peak.
        with ThreadPoolExecutor() as executor:
            sampler = monitor_cls()
            sampling = executor.submit(measure, sampler)
            try:
                executor.submit(func).result()
            finally:
                sampler.keep_measuring = False
                peak = sampling.result()
        return peak

    if is_gpu:
        baseline = start_memory if start_memory is not None else baseline_monitor.measure_gpu_usage()
        if baseline is None:
            return None
        if func is None:
            return baseline

        peak = track_peak_while_running(monitor_cls.measure_gpu_usage)
        if peak is None:
            return None

        logger.info(f"GPU memory usage: before={baseline} peak={peak}")
        if len(baseline) == 0 or len(peak) == 0 or len(baseline) != len(peak):
            return None
        # With several GPUs, report the device whose usage grew the most.
        return max(0, *(after["max_used_MB"] - before["max_used_MB"] for before, after in zip(baseline, peak)))

    # CPU memory
    baseline = start_memory if start_memory is not None else baseline_monitor.measure_cpu_usage()
    if func is None:
        return baseline

    peak = track_peak_while_running(monitor_cls.measure_cpu_usage)
    logger.info(f"CPU memory usage: before={baseline:.1f} MB, peak={peak:.1f} MB")
    return peak - baseline
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def get_ort_environment_variables():
    """Report which ORT tuning environment variables are set.

    These variables might impact ORT performance on transformer models and are for testing only.
    Returns a comma-separated string of ``NAME=value`` pairs in a fixed order. Unset variables
    are skipped, while a variable set to an empty string is still reported. Returns an empty
    string when none are set.
    """
    watched_names = (
        "ORT_DISABLE_FUSED_ATTENTION",
        "ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
        "ORT_DISABLE_FUSED_CROSS_ATTENTION",
        "ORT_DISABLE_TRT_FLASH_ATTENTION",
        "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
        "ORT_TRANSFORMER_OPTIONS",
        "ORT_CUDA_GEMM_OPTIONS",
    )
    settings = []
    for var_name in watched_names:
        setting = os.getenv(var_name)
        if setting is not None:
            settings.append(f"{var_name}={setting}")
    return ",".join(settings)
|