onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
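
For orientation, a minimal sketch (not part of the package diff) of exercising this DirectML build once the wheel is installed. It uses only stable public APIs and the bundled sample model listed above (onnxruntime/datasets/sigmoid.onnx); DmlExecutionProvider is the provider this build adds, with CPU as fallback:

    import numpy as np
    import onnxruntime as ort
    from onnxruntime.datasets import get_example

    # DmlExecutionProvider is registered by this build (backed by capi/DirectML.dll);
    # CPUExecutionProvider serves as a fallback if DirectML is unavailable.
    sess = ort.InferenceSession(
        get_example("sigmoid.onnx"),
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )
    inp = sess.get_inputs()[0]
    x = np.random.rand(*inp.shape).astype(np.float32)  # sigmoid.onnx has a fixed input shape
    outputs = sess.run(None, {inp.name: x})
    print(outputs[0].shape)
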
--- /dev/null
+++ b/onnxruntime/transformers/fusion_embedlayer.py
@@ -0,0 +1,810 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from logging import getLogger
+
+from fusion_base import Fusion
+from fusion_utils import FusionUtils
+from onnx import NodeProto, TensorProto, helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionEmbedLayerNoMask(Fusion):
+    """
+    Fuse embedding layer into one node (EmbedLayerNormalization).
+    It supports the following model types: BERT, DistilBert, ALBert.
+    """
+
+    def __init__(self, model: OnnxModel, description: str = "no mask"):
+        super().__init__(
+            model,
+            "EmbedLayerNormalization",
+            ["LayerNormalization", "SkipLayerNormalization"],
+            description,
+        )
+        self.utils = FusionUtils(model)
+        self.shape_infer = None
+        self.shape_infer_done = False
+
+        # The following will be reset in each fuse call of FusionEmbedLayerNormalization
+        self.attention = None
+        self.embed_node = None
+
+    def match_two_gather(self, add: NodeProto) -> None | tuple[NodeProto, NodeProto]:
+        gather_0_path = self.model.match_parent_path(add, ["Gather"], [0])
+        if gather_0_path is None:
+            return None
+
+        gather_1_path = self.model.match_parent_path(add, ["Gather"], [1])
+        if gather_1_path is None:
+            return None
+
+        return gather_0_path[0], gather_1_path[0]
+
+    def check_attention_subgraph(
+        self,
+        layernorm: NodeProto,
+        input_name_to_nodes: dict[str, list[NodeProto]],
+        is_distil_bert: bool,
+    ) -> bool:
+        """Check that LayerNormalization has a child of Attention node or subgraph like Attention.
+
+        Args:
+            layernorm (NodeProto): LayerNormalization node
+            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
+            is_distil_bert (bool): whether it is DistilBert or not
+
+        Returns:
+            bool: whether there is Attention node or subgraph like Attention
+        """
+        self.attention = self.model.find_first_child_by_type(
+            layernorm, "Attention", input_name_to_nodes, recursive=False
+        )
+
+        if self.attention is not None:
+            return True
+
+        if layernorm.output[0] not in input_name_to_nodes:
+            return False
+        children = input_name_to_nodes[layernorm.output[0]]
+        children_types = sorted([child.op_type for child in children])
+
+        # Try to find MultiHeadAttention
+        if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
+            for node in children:
+                if node.op_type == "SkipLayerNormalization":
+                    path1 = self.model.match_parent_path(
+                        node,
+                        ["Add", "MatMul", "MultiHeadAttention", "MatMul"],
+                        [None, None, 0, 0],
+                    )
+                    if path1 is not None and path1[-1].input[0] == layernorm.output[0]:
+                        self.cross_attention = path1[2]
+                        return True
+
+        # In case the user disables attention fusion, check whether the subgraph looks like Attention.
+        # For Albert, there is MatMul+Add after embedding layer before attention.
+        if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
+            grandchildren = input_name_to_nodes[children[0].output[0]]
+            if (
+                len(grandchildren) == 1
+                and grandchildren[0].op_type == "Add"
+                and grandchildren[0].output[0] in input_name_to_nodes
+            ):
+                nodes = input_name_to_nodes[grandchildren[0].output[0]]
+                for node in nodes:
+                    if node.op_type == "Attention":
+                        self.attention = node
+                        return True
+                children_types = sorted([child.op_type for child in nodes])
+
+        # Two Shape nodes might be merged by ORT
+        if is_distil_bert:
+            # SkipLayerNormalization might exist when model has been optimized by ORT first.
+            if (
+                children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
+                and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
+                and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
+            ):
+                logger.debug("No Attention like subgraph in children of LayerNormalization")
+                return False
+        else:
+            if children_types != [
+                "Add",
+                "MatMul",
+                "MatMul",
+                "MatMul",
+            ] and children_types != [
+                "MatMul",
+                "MatMul",
+                "MatMul",
+                "SkipLayerNormalization",
+            ]:
+                logger.debug("No Attention like subgraph in children of LayerNormalization")
+                return False
+
+        return True
+
+    def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
+        """Match position embedding path from input_ids to Gather for DistilBert.
+
+        Pattern is like the following:
+                 (input_ids)
+                      |
+                    Shape
+                      |   \
+                      |    Gather (indices=1)
+                      |       |
+                      |      Cast (optional)
+                      |       |
+                      |      Range (start=0, end=*, delta=1)
+                      |       |
+                      |      Unsqueeze
+                      |    /
+                    Expand
+                      |
+                    Gather
+        """
+        # remove after tests pass
+        path1 = self.model.match_parent_path(position_embedding_gather, ["Expand", "Shape"], [1, 1])
+        if path1 is None:
+            path1 = self.model.match_parent_path(
+                position_embedding_gather,
+                ["Expand", "Where", "Reshape", "Shape"],
+                [1, 1, 2, 0],
+            )
+            if path1 is None:
+                return False
+
+        expand, shape = path1[0], path1[-1]
+        if shape.input[0] != input_ids:
+            return False
+
+        _, path2, _ = self.model.match_parent_paths(
+            expand,
+            [
+                (["Unsqueeze", "Range", "Cast", "Gather", "Shape"], [0, 0, 1, 0, 0]),
+                (["Unsqueeze", "Range", "Gather", "Shape"], [0, 0, 1, 0]),
+            ],
+            output_name_to_node,
+        )
+        if path2 is None:
+            return False
+
+        range_node = path2[1]
+        if not (
+            self.utils.check_node_input_value(range_node, 0, 0) and self.utils.check_node_input_value(range_node, 2, 1)
+        ):
+            return False
+
+        gather_node = path2[-2]
+        if not (self.utils.check_node_input_value(gather_node, 1, 1)):
+            return False
+
+        shape_node = path2[-1]
+        if shape_node.input[0] != input_ids:
+            return False
+
+        return True
+
+    def match_position_embedding_roberta(self, position_embedding_gather, input_ids, output_name_to_node):
+        """Match position embedding path from input_ids to Gather for Roberta.
+
+        Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id):
+          (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather
+                |                                                                                      ^
+                V                                                                                      |
+                +--------------------------------------------------------------------------------------+
+
+        Roberta new pattern from transformers v4.9:
+          (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather
+                |                                                                                      ^
+                V                                                                                      |
+                +--------------------------------------------------------------------------------------+
+
+        start_node = position_embedding_gather
+        start_index = 1
+
+        # match optional Cast node.
+        parent = self.model.get_parent(start_node, start_index, output_name_to_node)
+        if parent is None:
+            return
+        if parent.op_type == "Cast":
+            if OnnxModel.get_node_attribute(parent, "to") != 7:
+                return
+            start_node = parent
+            start_index = 0
+
+        i, path, return_indices = self.model.match_parent_paths(
+            start_node,
+            [ (['Add', 'Cast', 'Mul', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0]),
+              (['Add', 'Cast', 'Mul', 'Add', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0, 0])],
+            output_name_to_node)
+
+        if path is not None:
+            # constant input of Add shall be 1.
+            i, value = self.model.get_constant_input(path[0])
+            if value != 1:
+                return False
+
+            _, self.padding_word_id = self.model.get_constant_input(path[-1])
+
+            return input_ids == path[-1].input[0]
+        """
+
+        return False
+
+    def match_position_embedding_bert(self, position_embedding_gather, input_ids, output_name_to_node):
+        """Match position embedding path from input_ids to Gather for BERT.
+
+        BERT Embedding Layer Pattern:
+                                    (input_ids)
+                                   /         \
+                                  /         Shape
+                                 /             |
+                                /            Gather (indices=1)
+                               /               |
+                              /              Add (optional, B=0)
+                             /                 |
+            Gather (segment_ids)            Unsqueeze (axes=0)
+                 \        |                    |
+                  \     Gather               Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
+                   \    /                      |
+                    Add                      Gather
+                       \                     /
+                        Add
+                         |
+                  LayerNormalization
+        """
+        path = self.model.match_parent_path(
+            position_embedding_gather,
+            ["Slice", "Unsqueeze"],
+            [1, 2],
+            output_name_to_node,
+        )
+        if path is None:
+            return False
+
+        slice, unsqueeze = path
+        slice_weight = self.model.get_constant_value(slice.input[0])
+        if not (
+            slice_weight is not None
+            and len(slice_weight.shape) == 2
+            and slice_weight.shape[0] == 1
+            and self.utils.check_node_input_value(slice, 1, [0])
+            and self.utils.check_node_input_value(slice, 3, [1])
+            and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1]))
+        ):
+            return False
+
+        opset_version = self.model.get_opset_version()
+        if opset_version < 13:
+            if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
+                return False
+        else:
+            if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
+                return False
+
+        node = self.model.get_parent(unsqueeze, 0, output_name_to_node)
+        if node is None:
+            return False
+        if node.op_type == "Add":
+            if not self.utils.check_node_input_value(node, 1, 0):
+                return False
+            gather = self.model.get_parent(node, 0, output_name_to_node)
+        else:
+            gather = node
+
+        if gather is None or gather.op_type != "Gather":
+            return False
+        if not (self.utils.check_node_input_value(gather, 1, 1)):
+            return False
+
+        shape = self.model.get_parent(gather, 0, output_name_to_node)
+        if shape is None or shape.op_type != "Shape":
+            return False
+
+        return input_ids == shape.input[0]
+
+    def match_position_embedding(self, position_embedding_gather, input_ids, output_name_to_node):
+        if self.match_position_embedding_bert(position_embedding_gather, input_ids, output_name_to_node):
+            return True
+
+        # TODO: Support roberta (position starts from 2 instead of 0) in EmbedLayerNormalization kernel
+        # related: https://github.com/huggingface/transformers/issues/10736
+        # if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node):
+        #    return True
+
+        if self.match_position_embedding_distilbert(position_embedding_gather, input_ids, output_name_to_node):
+            return True
+
+        return False
+
+    def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
+        """Sanity check of embedding weights, and match hidden_size of weights and shape of inputs."""
+        input_ids = word_embedding_gather.input[1]
+        segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
+        position_ids = position_embedding_gather.input[1]
+
+        if not self.shape_infer_done:
+            self.shape_infer = self.model.infer_runtime_shape(update=True)
+            self.shape_infer_done = True
+
+        if self.shape_infer is not None:
+            input_ids_shape = self.shape_infer.get_edge_shape(input_ids)
+            position_ids_shape = self.shape_infer.get_edge_shape(position_ids)
+            assert input_ids_shape and position_ids_shape
+            if not (
+                len(input_ids_shape) == 2
+                and len(position_ids_shape) == 2
+                and input_ids_shape[1] == position_ids_shape[1]
+            ):
+                logger.info(
+                    f"Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {input_ids_shape} vs {position_ids_shape}"
+                )
+                return False
+
+            if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids):
+                logger.info(
+                    f"Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {input_ids_shape} != {self.shape_infer.get_edge_shape(segment_ids)}"
+                )
+                return False
+
+        word_embedding_table = self.model.get_constant_value(word_embedding_gather.input[0])
+        if word_embedding_table is None or len(word_embedding_table.shape) != 2:
+            logger.info("Cannot fuse EmbedLayerNormalization: word embedding table is not expected")
+            return False
+
+        position_embedding_table = self.model.get_constant_value(position_embedding_gather.input[0])
+        if (
+            position_embedding_table is None
+            or len(position_embedding_table.shape) != 2
+            or (word_embedding_table.shape[1] != position_embedding_table.shape[1])
+        ):
+            logger.info("Cannot fuse EmbedLayerNormalization: position embedding table is not expected")
+            return False
+
+        if segment_ids:
+            segment_embedding_table = self.model.get_constant_value(segment_embedding_gather.input[0])
+            if (
+                segment_embedding_table is None
+                or len(segment_embedding_table.shape) != 2
+                or (word_embedding_table.shape[1] != segment_embedding_table.shape[1])
+            ):
+                logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
+                return False
+
+        # In the normal case, word embedding table is the largest, and segment embedding table is the smallest, while position embedding table is in between.
+        # TODO: use other information (like initializer names) to identify different embedding weights automatically.
+        if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
+            logger.warning(
+                f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]}"
+            )
+
+        if segment_ids:
+            if word_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
+                logger.warning(
+                    f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
+                )
+
+            if position_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
+                logger.warning(
+                    f"position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
+                )
+
+        return True
+
+    def cast_to_int32(self, input_name: str) -> tuple[str, None | NodeProto]:
+        """Cast a graph input or node input to int32.
+
+        Args:
+            input_name (str): name of graph input or node input
+
+        Returns:
+            A tuple of casted input name and the cast node.
+            int32_output (str): If input is int32, it is the input name. Otherwise it is the output name of the Cast node.
+            input_cast_node (Union[None, NodeProto]): Cast node. It could be None if input is int32.
+        """
+        input_cast_node = None
+        graph_input = self.model.find_graph_input(input_name)
+        if graph_input is not None:
+            if graph_input.type.tensor_type.elem_type != TensorProto.INT32:
+                int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
+            else:
+                int32_output = input_name
+        else:
+            int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
+
+        return int32_output, input_cast_node
+
+    def create_fused_node(
+        self,
+        input_ids: str,
+        layernorm: NodeProto,
+        word_embedding_gather: NodeProto,
+        position_embedding_gather: NodeProto,
+        segment_embedding_gather: None | NodeProto,
+        position_ids: str | None = None,
+        embedding_sum_output=False,
+        embedding_sum_name=None,
+    ):
+        """Create an EmbedLayerNormalization node. Note that segment embedding is optional.
+
+        Args:
+            input_ids (str): input_ids for word embeddings
+            layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node.
+            word_embedding_gather (NodeProto): the Gather node for word embedding
+            position_embedding_gather (NodeProto): the Gather node for position embedding
+            segment_embedding_gather (Union[None, NodeProto]): the Gather node for segment embedding, or None.
+
+        Returns:
+            NodeProto: the EmbedLayerNormalization node created.
+        """
+        nodes_to_add = []
+        input_ids, _ = self.cast_to_int32(input_ids)
+
+        node_name = self.model.create_node_name("EmbedLayerNormalization")
+
+        if layernorm.op_type == "LayerNormalization":
+            gamma = layernorm.input[1]
+            beta = layernorm.input[2]
+        else:  # SkipLayerNormalization
+            gamma = layernorm.input[2]
+            beta = layernorm.input[3]
+
+        embed_node_inputs = None
+        if segment_embedding_gather is not None:
+            segment_ids, _ = self.cast_to_int32(segment_embedding_gather.input[1])
+
+            embed_node_inputs = [
+                input_ids,
+                segment_ids,
+                word_embedding_gather.input[0],
+                position_embedding_gather.input[0],
+                segment_embedding_gather.input[0],
+                gamma,
+                beta,
+            ]
+        else:  # no segment embedding
+            embed_node_inputs = [
+                input_ids,
+                "",
+                word_embedding_gather.input[0],
+                position_embedding_gather.input[0],
+                "",
+                gamma,
+                beta,
+            ]
+
+        if position_ids is not None:
+            # Adding an empty input for mask before position_ids
+            embed_node_inputs.append("")
+            position_ids, _ = self.cast_to_int32(position_ids)
+            embed_node_inputs.append(position_ids)
+
+        embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"]
+        if embedding_sum_output:
+            name = embedding_sum_name if embedding_sum_name is not None else node_name + "_embedding_sum"
+            embed_node_outputs.append(name)
+
+        embed_node = helper.make_node(
+            "EmbedLayerNormalization",
+            embed_node_inputs,
+            outputs=embed_node_outputs,
+            name=node_name,
+        )
+
+        embed_node.domain = "com.microsoft"
+
+        # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
+        for att in layernorm.attribute:
+            if att.name == "epsilon":
+                embed_node.attribute.extend([att])
+
+        # Set default value to 1e-12 if no attribute is found.
+        # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
+        if len(embed_node.attribute) == 0:
+            embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
+
+        # Make sure new EmbedLayerNormalization node is the last one in self.nodes_to_add.
+        nodes_to_add.append(embed_node)
+        for node in nodes_to_add:
+            self.node_name_to_graph_name[node.name] = self.this_graph_name
+        self.nodes_to_add.extend(nodes_to_add)
+
+        self.embed_node = embed_node
+        return embed_node
+
+    def finish_fusion(self, layernorm, embed_node):
+        self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
+        # use prune graph to remove nodes that are not needed
+        self.prune_graph = True
+
+    def is_skip_layer_norm_with_sum_output(self, node):
+        return (node.op_type == "SkipLayerNormalization") and len(node.output) > 3 and len(node.output[3]) > 0
+
+    def fuse_gpt2(
+        self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None
+    ):
+        # graph checks
+        # gpt2 has optional segment embedding, subgraph pattern is like
+        #                      input_ids   position_ids
+        #                          |            |
+        #    token_ids          Gather       Gather
+        #        |                  \          /
+        #     Gather (optional)       Add _ _ _ _ _
+        #          \                   |           |
+        #           LayerNormalization             |
+        #                    |                     |
+        #                Attention                 |
+        #                    |                     |
+        #                  Matmul                  |
+        #                    |                    /
+        #                   Add                  /
+        #                     \                 /
+        #                          Add
+        two_gather = self.match_two_gather(add_before_layernorm)
+        if two_gather is None:
+            return False
+
+        word_embedding_gather, position_embedding_gather = two_gather
+        input_ids = word_embedding_gather.input[1]
+        position_ids = position_embedding_gather.input[1]
+
+        if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
+            return False
+
+        if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
+            return False
+
+        # If the layernorm node is SkipLayerNormalization, we need to look at its optional fourth output.
+        # If the add_before_layernorm node is an Add node, then the add output is the first output of this node.
+        # If the add_before_layernorm node is a SkipLayerNormalization node, then the add output
+        # is the (optional) fourth output of this node.
+        # When add_before_layernorm is SkipLayerNormalization, add_before_layernorm and layernorm are the same node.
+        if layernorm.op_type == "SkipLayerNormalization":
+            need_embedding_sum_output = self.is_skip_layer_norm_with_sum_output(layernorm)
+            sum_output_index = 3
+            node_with_sum_output = layernorm
+            sum_output = layernorm.output[3] if need_embedding_sum_output else None
+            is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
+        else:  # layernorm.op_type == "LayerNormalization"
+            node_with_sum_output = add_before_layernorm
+            sum_output_index = 0 if add_before_layernorm.op_type == "Add" else 3
+            sum_output = (
+                add_before_layernorm.output[sum_output_index]
+                if len(add_before_layernorm.output) > sum_output_index
+                else None
+            )
+            is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
+            is_sum_used_by_multiple_nodes = (
+                sum_output and (sum_output in input_name_to_nodes) and len(input_name_to_nodes[sum_output]) > 1
+            )
+            need_embedding_sum_output = (sum_output is not None) and (
+                add_before_layernorm.op_type != "Add" or is_sum_graph_output or is_sum_used_by_multiple_nodes
+            )
+
+        # make the fused node
+        embed_node = self.create_fused_node(
+            input_ids,
+            layernorm,
+            word_embedding_gather,
+            position_embedding_gather,
+            optional_segment_gather,
+            position_ids,
+            embedding_sum_output=need_embedding_sum_output,
+            embedding_sum_name=sum_output if is_sum_graph_output else None,
+        )
+
+        if need_embedding_sum_output:
+            node_with_sum_output.output[sum_output_index] = "_no_use__to_be_removed_"
+            if not is_sum_graph_output:
+                self.model.replace_input_of_all_nodes(sum_output, embed_node.output[2])
+
+        self.finish_fusion(layernorm, embed_node)
+        return True
+
+    def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
+        """Fuse embedding layer for DistilBert
+        Args:
+            layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
+            add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
+            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
+            output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes
+        """
+
+        # DistilBert has no segment embedding, subgraph pattern is like
+        #       input_ids
+        #        |      \
+        #        |     (position_embedding_subgraph)
+        #        |        |
+        #      Gather   Gather
+        #          \   /
+        #           Add
+        #            |
+        #    LayerNormalization
+        two_gather = self.match_two_gather(add_before_layernorm)
+        if two_gather is None:
+            return False
+
+        word_embedding_gather, position_embedding_gather = two_gather
+        input_ids = word_embedding_gather.input[1]
+
+        if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=True):
+            return False
+
+        if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
+            return False
+
+        if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
+            return False
+
+        embed_node = self.create_fused_node(
+            input_ids, layernorm, word_embedding_gather, position_embedding_gather, None
+        )
+        self.finish_fusion(layernorm, embed_node)
+        return True
+
+    def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
+        """Fuse embedding layer for Bert
+        Args:
+            layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
+            add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
+            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
+            output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes
+        """
+
+        add_2_gather = self.model.match_parent_path(add_before_layernorm, ["Add"], [0])
+        if add_2_gather is None:
+            return False
+
+        two_gather = self.match_two_gather(add_2_gather[0])
+        if two_gather is None:
+            return False
+
+        word_embedding_gather, segment_embedding_gather = two_gather
+
+        input_ids = word_embedding_gather.input[1]
+
+        if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
+            return False
+
+        position_embedding_path = self.model.match_parent_path(add_before_layernorm, ["Gather"], [1])
+        if position_embedding_path is None:
+            return False
+
+        position_embedding_gather = position_embedding_path[0]
+        if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
+            if not self.match_position_embedding(segment_embedding_gather, input_ids, output_name_to_node):
+                return False
+            # position and segment are switched
+            temp = segment_embedding_gather
+            segment_embedding_gather = position_embedding_gather
+            position_embedding_gather = temp
+
+        if not self.check_embedding(word_embedding_gather, segment_embedding_gather, position_embedding_gather):
+            return False
+
+        embed_node = self.create_fused_node(
+            input_ids,
+            layernorm,
+            word_embedding_gather,
+            position_embedding_gather,
+            segment_embedding_gather,
+        )
+        self.finish_fusion(layernorm, embed_node)
+        return True
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        first_add_path = self.model.match_parent_path(node, ["Add"], [0])
+        if node.op_type == "LayerNormalization":
+            if first_add_path is None:
+                return
+            add_before_layernorm = first_add_path[0]
+            optional_segment_gather = None
+        else:  # SkipLayerNormalization
+            gather_0_path = self.model.match_parent_path(node, ["Gather"], [0])
+            gather_1_path = self.model.match_parent_path(node, ["Gather"], [1])
+            if gather_0_path is None and gather_1_path is not None:
+                if first_add_path is None:
+                    return
+                add_before_layernorm = first_add_path[0]
+                optional_segment_gather = gather_1_path[0]
+            elif gather_0_path is not None and gather_1_path is None:
+                first_add_path = self.model.match_parent_path(node, ["Add"], [1])
+                if first_add_path is None:
+                    return
+                add_before_layernorm = first_add_path[0]
+                optional_segment_gather = gather_0_path[0]
+            else:
+                add_before_layernorm = node  # Add is fused into SkipLayerNormalization
+                optional_segment_gather = None
+
+        if self.fuse_gpt2(
+            node, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather
+        ):
+            return
+
+        if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
+            return
+
+        if self.fuse_bert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
+            return
+
+
+class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
+    def __init__(self, model: OnnxModel, use_mask_index=False):
+        super().__init__(model, "with mask")
+        self.use_mask_index = use_mask_index
+
+    def replace_mask(self, mask_int32, attention_nodes):
+        # Inputs of EmbedLayerNorm: input_ids, segment_ids (optional), word_embedding, position_embedding,
+        #                           segment_embedding (optional), gamma, beta, mask (optional), position_ids (optional)
+        embed_node = self.embed_node
+        if len(embed_node.input) == 7:
+            embed_node.input.append(mask_int32)
+            logger.debug("append mask to %s", embed_node.name)
+        elif len(embed_node.input) > 7 and not embed_node.input[7]:
+            embed_node.input[7] = mask_int32
+            logger.debug("replace mask in %s", embed_node.name)
+        else:
+            logger.debug("skip mask in %s", embed_node.name)
+            return
+
+        for attention_node in attention_nodes:
+            logger.debug("update mask_index in %s", attention_node.name)
+            if attention_node.op_type == "Attention":
+                attention_node.input[3] = embed_node.output[1]
+            elif attention_node.op_type == "MultiHeadAttention":
+                attention_node.input[4] = embed_node.output[1]
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        # Reset attention and embed_node so that we know fusion is successful when they are not None.
+        self.attention = None
+        self.cross_attention = None
+        self.embed_node = None
+        super().fuse(node, input_name_to_nodes, output_name_to_node)
+
+        if self.embed_node is None:
+            return
+
+        if not self.use_mask_index:
+            logger.debug("--use_mask_index is not set: EmbedLayerNormalization will not have mask")
+            self.increase_counter("EmbedLayerNormalization(no mask)")
+            return
+
+        if self.attention is None and self.cross_attention is None:
+            logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
+            self.increase_counter("EmbedLayerNormalization(no mask)")
+            return
+
+        if self.attention:
+            mask_int32 = self.attention.input[3]
+        else:
+            mask_int32 = self.cross_attention.input[4]
+
+        children_nodes = input_name_to_nodes[mask_int32]
+        if self.model.find_graph_input(mask_int32):
+            attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
+            self.replace_mask(mask_int32, attention_nodes)
+            self.increase_counter("EmbedLayerNormalization(with mask)")
+            return
+
+        if mask_int32 not in output_name_to_node:
+            logger.debug("EmbedLayerNormalization will not have mask since %s is not a node output", mask_int32)
+            self.increase_counter("EmbedLayerNormalization(no mask)")
+            return
+
+        node = output_name_to_node[mask_int32]
+        if node.op_type in ["ReduceSum", "Cast"]:
+            attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
+            if node.op_type == "ReduceSum":
+                mask_int32 = node.input[0]
+                if len(children_nodes) == len(attention_nodes):
+                    self.nodes_to_remove.append(node)
+            self.replace_mask(mask_int32, attention_nodes)
+            self.increase_counter("EmbedLayerNormalization(with mask)")