onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
`onnxruntime/transformers/past_helper.py` (new file, +149 lines):

```diff
@@ -0,0 +1,149 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class PastKeyValuesHelper:
+    """Helper functions to process past key values for encoder-decoder model"""
+
+    @staticmethod
+    def get_past_names(num_layers, present: bool = False):
+        past_self_names = []
+        past_cross_names = []
+        for i in range(num_layers):
+            past_self_names.extend(
+                [f"present_key_self_{i}", f"present_value_self_{i}"]
+                if present
+                else [f"past_key_self_{i}", f"past_value_self_{i}"]
+            )
+            past_cross_names.extend(
+                [f"present_key_cross_{i}", f"present_value_cross_{i}"]
+                if present
+                else [f"past_key_cross_{i}", f"past_value_cross_{i}"]
+            )
+        return past_self_names + past_cross_names
+
+    @staticmethod
+    def group_by_self_or_cross(present_key_values):
+        """Split present state from grouped by layer to grouped by self/cross attention.
+        Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+        After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...), (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
+
+        """
+        present_self = []
+        present_cross = []
+        for _i, present_layer_i in enumerate(present_key_values):
+            assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
+            (
+                present_key_self,
+                present_value_self,
+                present_key_cross,
+                present_value_cross,
+            ) = present_layer_i
+            present_self.extend([present_key_self, present_value_self])
+            present_cross.extend([present_key_cross, present_value_cross])
+        return present_self, present_cross
+
+    @staticmethod
+    def group_by_layer(past, num_layers):
+        """Reorder past state from grouped by self/cross attention to grouped by layer.
+        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ..., past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
+        After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
+        """
+        assert len(past) == 4 * num_layers
+        return tuple(
+            [
+                past[2 * i],
+                past[2 * i + 1],
+                past[2 * num_layers + 2 * i],
+                past[2 * num_layers + 2 * i + 1],
+            ]
+            for i in range(num_layers)
+        )
+
+    @staticmethod
+    def back_group_by_layer(past_key_values: tuple[tuple[torch.Tensor]]):
+        """Categorize present_key_values from self and cross attention to layer by layer.
+
+        Reorder past state from grouped by self/cross attention to grouped by layer.
+        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
+                past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
+        After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
+               (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
+
+        Args:
+            present_key_values: From past_key_values of a model (group by self and cross attention)
+
+        Returns:
+            past_tuples: present key and values grouped by layer.
+        """
+        past_tuples = ()
+        half_idx = len(past_key_values) // 2
+        for i in range(len(past_key_values) // 4):
+            idx = 2 * i
+            past_tuples += (
+                (
+                    past_key_values[idx],
+                    past_key_values[idx + 1],
+                    past_key_values[half_idx + idx],
+                    past_key_values[half_idx + idx + 1],
+                ),
+            )
+        return past_tuples
+
+    @staticmethod
+    def group_by_self_and_cross(present_key_values: tuple[torch.Tensor], concat: bool = False):
+        """Categorize present_key_values into self and cross attention.
+
+        Split present state from grouped by layer to grouped by self/cross attention.
+        Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
+                (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+        After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...),
+               (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
+
+        Args:
+            present_key_values: From past_key_values of a model (group by layer)
+            concat: If concat self attention with cross attention key/value to return
+
+        Returns:
+            present_self (Tuple[torch.Tensor]): present key and values from self attention
+            present_cross (Tuple[torch.Tensor]): present key and values from cross attention
+        """
+        present_self: list[torch.Tensor] = []
+        present_cross: list[torch.Tensor] = []
+        for _, present_layer_i in enumerate(present_key_values):
+            assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
+            present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i
+            present_self.extend([present_key_self, present_value_self])
+            present_cross.extend([present_key_cross, present_value_cross])
+        if concat:
+            return present_self + present_cross
+        else:
+            return present_self, present_cross
+
+    @staticmethod
+    def get_input_names(past_key_values: tuple[tuple[torch.Tensor]], encoder=True):
+        """Process input names of model wrapper.
+
+        Args:
+            past_key_values: Consider `self` and `cross` past_key_values
+
+        Returns:
+            names (List[string]): input names
+        """
+        names = []
+        num_layers = len(past_key_values) // 4 if encoder else len(past_key_values)
+        prefix = "past_" if not encoder else "present_"
+        for i in range(num_layers):
+            names.extend([prefix + s for s in [f"key_self_{i}", f"value_self_{i}"]])
+        for i in range(num_layers):
+            names.extend([prefix + s for s in [f"key_cross_{i}", f"value_cross_{i}"]])
+        return names
```
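For orientation, here is a minimal usage sketch of the helpers above. It is not part of the package; the layer count and tensor shapes are placeholder assumptions, and the module path follows the wheel layout in the file listing:

```python
# Hypothetical sketch: round-trip a fake decoder past state through the helpers.
import torch

from onnxruntime.transformers.past_helper import PastKeyValuesHelper

num_layers = 2  # placeholder; real models have more layers
# Fake per-layer state: (key_self, value_self, key_cross, value_cross) per layer,
# with arbitrary (batch, heads, seq_len, head_size) shapes.
per_layer = tuple(tuple(torch.zeros(1, 8, 4, 64) for _ in range(4)) for _ in range(num_layers))

# Flatten to (all self-attention tensors..., then all cross-attention tensors...).
flat = PastKeyValuesHelper.group_by_self_and_cross(per_layer, concat=True)
assert len(flat) == 4 * num_layers

# Regroup the flat sequence back into per-layer 4-tuples.
regrouped = PastKeyValuesHelper.back_group_by_layer(flat)
assert all(a is b for lr, lp in zip(regrouped, per_layer) for a, b in zip(lr, lp))

# ONNX input/output names follow the same flattened ordering.
print(PastKeyValuesHelper.get_past_names(num_layers))                # past_key_self_0, ...
print(PastKeyValuesHelper.get_past_names(num_layers, present=True))  # present_key_self_0, ...
```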
`onnxruntime/transformers/profile_result_processor.py` (new file, +358 lines):

```diff
@@ -0,0 +1,358 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+"""This profiler result processor print out the kernel time spent on each Node of the model.
+Example of importing profile result file from onnxruntime_perf_test:
+    python profile_result_processor.py --input profile_2021-10-25_12-02-41.json
+"""
+
+import argparse
+import json
+
+_NODES_TYPE_CONTAINING_SUBGRAPH = frozenset(("Scan", "Loop", "If"))
+
+
+def parse_arguments(argv=None):
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-i",
+        "--input",
+        required=False,
+        type=str,
+        help="Set the input file for reading the profile results",
+    )
+
+    parser.add_argument(
+        "--threshold",
+        required=False,
+        type=float,
+        default=0.01,
+        help="Threshold of run time ratio among all nodes. Nodes with larger ratio will show in top expensive nodes.",
+    )
+
+    parser.add_argument(
+        "--provider",
+        required=False,
+        type=str,
+        default="cuda",
+        help="Execution provider to use",
+    )
+
+    parser.add_argument(
+        "--kernel_time_only",
+        required=False,
+        action="store_true",
+        help="Only include the kernel time and no fence time",
+    )
+
+    parser.set_defaults(kernel_time_only=False)
+
+    parser.add_argument("-v", "--verbose", required=False, action="store_true")
+    parser.set_defaults(verbose=False)
+
+    return parser.parse_args(argv)
+
+
+def load_profile_json(profile_file):
+    print(f"loading profile output {profile_file} ...")
+
+    with open(profile_file) as opened_file:
+        sess_time = json.load(opened_file)
+
+    assert isinstance(sess_time, list)
+    return sess_time
+
+
+def parse_kernel_results(sess_time, threshold=0):
+    """Parse profile data and output nodes in two sections - nodes in the original order, and top expensive nodes.
+
+    Args:
+        sess_time (List[Dict]): profile data
+        threshold (int, optional): Minimum ratio of duration among all. Defaults to 0.
+
+    Returns:
+        List[str]: lines of string for output.
+    """
+    kernel_name_to_op_name = {}
+    kernel_time = {}
+    kernel_freq = {}
+    total = 0
+    session_init = False
+    for item in sess_time:
+        # Skip all MemcpyHostToDevice before session_initialization
+        if item["cat"] == "Session" and item["name"] == "session_initialization":
+            session_init = True
+        if not session_init:
+            continue
+
+        if item["cat"] == "Kernel" and "dur" in item and "args" in item and "op_name" in item["args"]:
+            kernel_name = item["name"]
+
+            op_name = item["args"]["op_name"]
+            if op_name in _NODES_TYPE_CONTAINING_SUBGRAPH:
+                continue
+
+            # Handle MemcpyHostToDevice and MemcpyDeviceToHost here
+            if not op_name:
+                op_name = f"({kernel_name})"
+
+            if kernel_name in kernel_time:
+                kernel_time[kernel_name] += item["dur"]
+                kernel_freq[kernel_name] += 1
+            else:
+                kernel_time[kernel_name] = item["dur"]
+                kernel_freq[kernel_name] = 1
+                kernel_name_to_op_name[kernel_name] = op_name
+
+            total += item["dur"]
+
+    if not kernel_time:
+        return ["No kernel record found!"]
+
+    # Output items with run time ratio > thresholds, and sorted by duration in the descending order.
+    lines = []
+    lines.append(f"\nTop expensive kernels with Time% >= {threshold * 100:.2f}:")
+    lines.append("-" * 64)
+    lines.append("Total(μs)\tTime%\tCalls\tAvg(μs)\tKernel")
+    for kernel_name, duration in sorted(kernel_time.items(), key=lambda x: x[1], reverse=True):
+        ratio = duration / total
+        if ratio < threshold:
+            continue
+
+        calls = kernel_freq[kernel_name]
+        avg_time = duration / float(calls)
+        lines.append(f"{duration:10d}\t{ratio * 100.0:5.2f}\t{calls:5d}\t{avg_time:8.1f}\t{kernel_name}")
+
+    # Group by operator
+    op_time = {}
+    for kernel_name, op_name in kernel_name_to_op_name.items():
+        duration = kernel_time[kernel_name]
+        if op_name in op_time:
+            op_time[op_name] += duration
+        else:
+            op_time[op_name] = duration
+
+    lines.append("\nGroup kernel time by operator:")
+    lines.append("-" * 64)
+    lines.append("Total(μs)\tTime%\tOperator")
+    for op_name, duration in sorted(op_time.items(), key=lambda x: x[1], reverse=True):
+        ratio = duration / total
+        lines.append(f"{duration:10d}\t{ratio * 100.0:5.2f}\t{op_name}")
+
+    return lines
+
+
+def parse_node_results(sess_time, kernel_time_only=False, threshold=0):
+    """Parse profile data and output nodes in two sections - nodes in the original order, and top expensive nodes.
+
+    Args:
+        sess_time (List[Dict]): profile data
+        kernel_time_only (bool, optional): Only include items for kernel time. Defaults to False.
+        threshold (int, optional): Minimum ratio of duration among all. Defaults to 0.
+
+    Returns:
+        List[str]: lines of string for output.
+    """
+    node_name_list = []
+    node_time = {}
+    node_freq = {}
+    node_provider = {}
+    total = 0
+    for item in sess_time:
+        if item["cat"] == "Node" and "dur" in item and "args" in item and "op_name" in item["args"]:
+            node_name = (
+                item["name"].replace("_kernel_time", "").replace("_fence_before", "").replace("_fence_after", "")
+            )
+
+            if "provider" in item["args"]:
+                if item["args"]["provider"] == "CPUExecutionProvider":
+                    device = "CPU"
+                elif item["args"]["provider"] == "CUDAExecutionProvider":
+                    device = "CUDA"
+                elif item["args"]["provider"] == "DmlExecutionProvider":
+                    device = "DML"
+
+                if node_name not in node_provider:
+                    node_provider[node_name] = device
+                else:
+                    assert node_provider[node_name] == device
+            elif kernel_time_only:
+                continue
+
+            op_name = item["args"]["op_name"]
+            if op_name in _NODES_TYPE_CONTAINING_SUBGRAPH:
+                continue
+
+            if node_name in node_time:
+                node_time[node_name] += item["dur"]
+                node_freq[node_name] += 1
+            else:
+                node_time[node_name] = item["dur"]
+                node_freq[node_name] = 1
+                node_name_list.append(node_name)
+
+            total += item["dur"]
+
+    # Output items in the original order.
+    lines = [
+        "\nNodes in the original order:",
+        "-" * 64,
+        "Total(μs)\tTime%\tAcc %\tAvg(μs)\tCalls\tProvider\tNode",
+    ]
+    before_percentage = 0.0
+    for node_name in node_name_list:
+        duration = node_time[node_name]
+        calls = node_freq[node_name]
+        avg_time = duration / float(calls)
+        percentage = (duration / total) * 100.0
+        provider = node_provider.get(node_name, "")
+        before_percentage += percentage
+        lines.append(
+            f"{duration:10d}\t{percentage:5.2f}\t{before_percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}"
+        )
+
+    # Output items with run time ratio > thresholds, and sorted by duration in the descending order.
+    lines.append(f"\nTop expensive nodes with Time% >= {threshold * 100:.2f}:")
+    lines.append("-" * 64)
+    lines.append("Total(μs)\tTime%\tAvg(μs)\tCalls\tProvider\tNode")
+    for node_name, duration in sorted(node_time.items(), key=lambda x: x[1], reverse=True):
+        ratio = duration / total
+        if ratio < threshold:
+            continue
+
+        calls = node_freq[node_name]
+        avg_time = duration / float(calls)
+        percentage = (duration / total) * 100.0
+        provider = node_provider.get(node_name, "")
+        lines.append(f"{duration:10d}\t{percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}")
+
+    return lines
+
+
+def group_node_results(sess_time):
+    """Group results by operator name.
+
+    Args:
+        sess_time (List[Dict]): profile data
+
+    Returns:
+        List[str]: lines of string for output.
+    """
+    op_kernel_time = {}
+    op_kernel_records = {}
+    total_kernel_time = 0
+
+    provider_op_kernel_time = {}
+    provider_op_kernel_records = {}
+    provider_kernel_time = {}
+
+    op_fence_time = {}
+    total_fence_time = 0
+
+    provider_counter = {}
+    for item in sess_time:
+        if item["cat"] == "Node" and "dur" in item and "args" in item and "op_name" in item["args"]:
+            op_name = item["args"]["op_name"]
+
+            # TODO: shall we have a separated group for nodes with subgraph?
+            if op_name in _NODES_TYPE_CONTAINING_SUBGRAPH:
+                continue
+
+            if "provider" not in item["args"]:
+                if "fence" in item["name"]:
+                    if op_name in op_fence_time:
+                        op_fence_time[op_name] += item["dur"]
+                    else:
+                        op_fence_time[op_name] = item["dur"]
+                    total_fence_time += item["dur"]
+                continue
+
+            provider = item["args"].get("provider", "")
+            if provider in provider_counter:
+                provider_counter[provider] += 1
+            else:
+                provider_counter[provider] = 1
+
+            key = f"{provider}:{op_name}"
+            if key in provider_op_kernel_time:
+                provider_op_kernel_time[key] += item["dur"]
+                provider_op_kernel_records[key] += 1
+            else:
+                provider_op_kernel_time[key] = item["dur"]
+                provider_op_kernel_records[key] = 1
+
+            if provider in provider_kernel_time:
+                provider_kernel_time[provider] += item["dur"]
+            else:
+                provider_kernel_time[provider] = item["dur"]
+
+            if op_name in op_kernel_time:
+                op_kernel_time[op_name] += item["dur"]
+                op_kernel_records[op_name] += 1
+            else:
+                op_kernel_time[op_name] = item["dur"]
+                op_kernel_records[op_name] = 1
+
+            total_kernel_time += item["dur"]
+
+    lines = ["", "Grouped by operator"]
+    lines.append("-" * 64)
+    lines.append("Total(μs)\tTime%\tKernel(μs)\tKernel%\tCalls\tAvgKernel(μs)\tFence(μs)\tOperator")
+    for op_name, kernel_time in sorted(op_kernel_time.items(), key=lambda x: x[1], reverse=True):
+        fence_time = op_fence_time.get(op_name, 0)
+        kernel_time_ratio = kernel_time / total_kernel_time
+        total_time = kernel_time + fence_time
+        time_ratio = total_time / (total_kernel_time + total_fence_time)
+        kernel_calls = op_kernel_records[op_name]
+        avg_kernel_time = kernel_time / kernel_calls
+        lines.append(
+            f"{total_time:10d}\t{time_ratio * 100.0:5.2f}\t{kernel_time:11d}\t{kernel_time_ratio * 100.0:5.2f}\t{kernel_calls:5d}\t{avg_kernel_time:14.1f}\t{fence_time:10d}\t{op_name}"
+        )
+
+    lines += ["", "Grouped by provider + operator"]
+    lines.append("-" * 64)
+    lines.append("Kernel(μs)\tProvider%\tCalls\tAvgKernel(μs)\tProvider\tOperator")
+    for key, kernel_time in sorted(provider_op_kernel_time.items(), key=lambda x: x[1], reverse=True):
+        parts = key.split(":")
+        provider = parts[0]
+        op_name = parts[1]
+        short_ep = provider.replace("ExecutionProvider", "")
+        calls = provider_op_kernel_records[key]
+        avg_kernel_time = kernel_time / calls
+        provider_time_ratio = kernel_time / provider_kernel_time[provider]
+        lines.append(
+            f"{kernel_time:10d}\t{provider_time_ratio * 100.0:9.2f}\t{calls:5d}\t{avg_kernel_time:14.1f}\t{short_ep:8s}\t{op_name}"
+        )
+
+    return lines
+
+
+def process_results(profile_file, args):
+    profile_records = load_profile_json(profile_file)
+
+    lines = parse_kernel_results(profile_records, args.threshold)
+
+    lines += parse_node_results(profile_records, args.kernel_time_only, args.threshold)
+
+    lines += group_node_results(profile_records)
+
+    return lines
+
+
+if __name__ == "__main__":
+    arguments = parse_arguments()
+    print("Arguments", arguments)
+
+    from benchmark_helper import setup_logger
+
+    setup_logger(arguments.verbose)
+
+    profile_file = arguments.input
+
+    results = process_results(profile_file, arguments)
+
+    for line in results:
+        print(line)
```
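As a hedged end-to-end sketch (not part of the package): a profile JSON that this script understands can also be produced by enabling profiling on an onnxruntime InferenceSession and then summarized with `process_results`. The model path, input shape, and run count below are placeholder assumptions:

```python
# Hypothetical sketch: generate a profile JSON with onnxruntime, then summarize it
# with the functions above. "model.onnx" and the zero-filled feed are placeholders.
import numpy as np
import onnxruntime as ort

from onnxruntime.transformers.profile_result_processor import parse_arguments, process_results

sess_options = ort.SessionOptions()
sess_options.enable_profiling = True  # profile file is written when end_profiling() is called

session = ort.InferenceSession("model.onnx", sess_options, providers=["DmlExecutionProvider"])
feed = {session.get_inputs()[0].name: np.zeros((1, 3, 224, 224), dtype=np.float32)}
for _ in range(10):  # several runs so kernel time dominates one-time session setup
    session.run(None, feed)

profile_file = session.end_profiling()  # returns the profile JSON file name

# Reuse the script's own argument parser so thresholds and flags match the CLI.
args = parse_arguments(["--input", profile_file, "--threshold", "0.02"])
for line in process_results(profile_file, args):
    print(line)
```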