onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)
This diff represents the contents of a publicly available package version released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions exactly as they appear in the public registry.
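For context, the wheel installs the standard onnxruntime Python API with the DirectML execution provider (DmlExecutionProvider) enabled. A minimal usage sketch; the model path and input shape below are placeholders, not taken from the package:

import numpy as np
import onnxruntime as ort

# Request the DirectML execution provider first, falling back to CPU.
session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: np.zeros((1, 3, 224, 224), dtype=np.float32)})
print([o.shape for o in outputs])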
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
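The hunks below reproduce three of the transformer graph-fusion passes listed above: fusion_qordered_attention.py, fusion_qordered_gelu.py and fusion_qordered_layernorm.py. For orientation, these passes are normally driven through the optimizer entry point rather than imported one by one; a hedged sketch with example BERT-base settings (paths and sizes are placeholders), noting that the QOrdered* fusions only fire when the input graph contains the quantized Q/DQ patterns they match:

from onnxruntime.transformers.optimizer import optimize_model

# Sketch only: num_heads and hidden_size are example values for a BERT-base style model.
opt_model = optimize_model("quantized_bert.onnx", model_type="bert", num_heads=12, hidden_size=768)
opt_model.save_model_to_file("quantized_bert_opt.onnx")
print(opt_model.get_fused_operator_statistics())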
onnxruntime/transformers/fusion_qordered_attention.py
@@ -0,0 +1,421 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger
from typing import Tuple

import numpy as np
from fusion_attention import AttentionMask
from fusion_base import Fusion
from fusion_utils import FusionUtils, NumpyHelper
from onnx import NodeProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionQOrderedAttention(Fusion):
    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_mask = attention_mask

        super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization")

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node.
        Args:
            reshape_q (NodeProto): reshape node for Q
        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """

        # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
        q_shape = self.model.get_initializer(reshape_q.input[1])
        if q_shape is None:
            logger.debug(f"{reshape_q.input[1]} is not initializer.")

            # Check if the second input to Reshape flows through a Constant node
            # TODO: Investigate why FusionAttention doesn't have such logic
            constant_node = self.model.match_parent_path(reshape_q, ["Constant"], [1])

            if constant_node is None:
                return self.num_heads, self.hidden_size  # Fall back to user specified value
            else:
                constant_node = constant_node[0]

                if len(constant_node.attribute) != 1:
                    return self.num_heads, self.hidden_size  # Fall back to user specified value

                # This is assuming it is a Tensor attribute (this is a safe assumption)
                q_shape = constant_node.attribute[0].t

        q_shape_value = NumpyHelper.to_array(q_shape)
        if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
            logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, head_size].")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        num_heads = q_shape_value[2]
        head_size = q_shape_value[3]
        hidden_size = num_heads * head_size

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        add_before_layernorm = self.model.match_parent_path(
            normalize_node,
            ["QuantizeLinear", "Add"],
            [0, 0],
        )

        if add_before_layernorm is not None:
            start_node = add_before_layernorm[-1]
        else:
            return

        # Input QDQ nodes
        dequantize_input = self.model.match_parent_path(
            start_node,
            ["DequantizeLinear"],
            [None],
        )

        if dequantize_input is None:
            logger.debug("fuse_qordered_attention: failed to match input qdq nodes path")
            return

        dequantize_input = dequantize_input[-1]

        # QKV nodes
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Add", "MatMul", "Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear", "MatMul"],
            [None, None, 0, 0, 0, 0, 0],
        )

        if qkv_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qkv path")
            return

        (_, projection_matmul, reshape_qkv, transpose_qkv, dequantize_qkv, quantize_qkv, matmul_qkv) = qkv_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_qkv, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qkv, self.model):
            return

        # Identify the root input to the Attention node
        other_inputs = []
        for _i, input in enumerate(start_node.input):
            if input not in output_name_to_node:
                continue

            if input == qkv_nodes[0].output[0]:
                continue

            other_inputs.append(input)

        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]

        # V nodes
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [1, 0, 0, 0, 0, None],
        )

        if v_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match v path")
            return

        (_, _, dequantize_v, quantize_v, add_v, matmul_v) = v_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_v, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v, self.model):
            return

        # V MatMul weight
        dequantize_v_matmul_weight = self.model.match_parent_path(matmul_v, ["DequantizeLinear"], [1])

        if dequantize_v_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match v path")
            return

        dequantize_v_matmul_weight = dequantize_v_matmul_weight[0]

        if self.model.get_constant_value(dequantize_v_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v_matmul_weight, self.model, False):
            return

        # QK nodes
        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            [
                "DequantizeLinear",
                "QuantizeLinear",
                "Softmax",
                "Add",
                "Div",
                "DequantizeLinear",
                "QuantizeLinear",
                "MatMul",
            ],
            [0, 0, 0, 0, None, 0, 0, 0],
        )

        if qk_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qk path")
            return

        (
            dequantize_qk_softmax,
            quantize_qk_softmax,
            softmax_qk,
            add_qk,
            div_qk,
            dequantize_qk,
            quantize_qk,
            matmul_qk,
        ) = qk_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk_softmax, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk_softmax, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk, self.model):
            return

        # Q nodes
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [0, 0, 0, 0, 0, None],
        )

        if q_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match q path")
            return

        (_, reshape_q, dequantize_q, quantize_q, add_q, matmul_q) = q_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_q, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q, self.model):
            return

        # Q MatMul weight
        dequantize_q_matmul_weight = self.model.match_parent_path(matmul_q, ["DequantizeLinear"], [1])

        if dequantize_q_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match q path")
            return

        dequantize_q_matmul_weight = dequantize_q_matmul_weight[0]

        if self.model.get_constant_value(dequantize_q_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q_matmul_weight, self.model, False):
            return

        # K nodes
        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [1, 0, 0, 0, 0, None],
        )

        if k_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match k path")
            return

        (_, _, dequantize_k, quantize_k, add_k, matmul_k) = k_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_k, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k, self.model):
            return

        # K MatMul weight
        dequantize_k_matmul_weight = self.model.match_parent_path(matmul_k, ["DequantizeLinear"], [1])

        if dequantize_k_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match k path")
            return

        dequantize_k_matmul_weight = dequantize_k_matmul_weight[0]

        if self.model.get_constant_value(dequantize_k_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k_matmul_weight, self.model, False):
            return

        # Mask nodes
        mask_nodes = self.model.match_parent_path(
            add_qk, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]
        )

        if mask_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match mask_nodes path")
            return

        # Ascertain `qkv_hidden_sizes` attribute value
        q_weight = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
        k_weight = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
        v_weight = self.model.get_initializer(dequantize_v_matmul_weight.input[0])

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)

        qw_out_size = np.prod(qw.shape[1:])
        kw_out_size = np.prod(kw.shape[1:])
        vw_out_size = np.prod(vw.shape[1:])

        # Form QOrderedAttention node
        if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])

            # Ascertain `num_heads` and `hidden_size`
            num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

            # Formulate the inputs
            # Actual quantized input
            attention_inputs = [dequantize_input.input[0]]
            attention_inputs.append(dequantize_input.input[1])

            attention_inputs.append(dequantize_q.input[1])
            attention_inputs.append(dequantize_k.input[1])
            attention_inputs.append(dequantize_v.input[1])

            attention_inputs.append(dequantize_q_matmul_weight.input[0])
            attention_inputs.append(dequantize_k_matmul_weight.input[0])
            attention_inputs.append(dequantize_v_matmul_weight.input[0])

            attention_inputs.append(dequantize_q_matmul_weight.input[1])
            attention_inputs.append(dequantize_k_matmul_weight.input[1])
            attention_inputs.append(dequantize_v_matmul_weight.input[1])

            if self.model.get_initializer(add_q.input[0]):
                attention_inputs.append(add_q.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_q.input[1])

            if self.model.get_initializer(add_k.input[0]):
                attention_inputs.append(add_k.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_k.input[1])

            if self.model.get_initializer(add_v.input[0]):
                attention_inputs.append(add_v.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_v.input[1])

            attention_inputs.append(quantize_qk.input[1])
            attention_inputs.append(quantize_qk_softmax.input[1])
            attention_inputs.append(dequantize_qkv.input[1])

            # Mask input
            if mask_index is not None:
                attention_inputs.append(mask_index)
            else:
                attention_inputs.append("")

            # The MatMul weight 'B' and 'bias' need some post-processing
            # Transpose weight 'B' from order ROW to order COL
            # This offline transpose is needed only while using the CUDA EP
            # TODO: Make this fusion logic EP-agnostic ?
            q_weight_tensor = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(q_weight_tensor)

            k_weight_tensor = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(k_weight_tensor)

            v_weight_tensor = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(v_weight_tensor)

            # Name and create Attention node
            attention_node_name = self.model.create_node_name("QOrderedAttention")

            attention_node = helper.make_node(
                "QOrderedAttention",
                inputs=attention_inputs,
                outputs=[reshape_qkv.output[0]],
                name=attention_node_name,
            )

            self.model.replace_node_input(dequantize_qkv, dequantize_qkv.input[0], attention_node.output[0])
            self.model.replace_node_input(projection_matmul, projection_matmul.input[0], dequantize_qkv.output[0])

            attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
            attention_node.attribute.extend([helper.make_attribute("order_input", 1)])
            attention_node.attribute.extend([helper.make_attribute("order_weight", 0)])
            attention_node.attribute.extend([helper.make_attribute("order_output", 1)])
            attention_node.attribute.extend(
                [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
            )

            attention_node.domain = "com.microsoft"

            self.nodes_to_add.append(attention_node)
            self.node_name_to_graph_name[attention_node.name] = self.this_graph_name

            self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, quantize_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)
            self.nodes_to_remove.extend(
                [dequantize_q_matmul_weight, dequantize_k_matmul_weight, dequantize_v_matmul_weight]
            )

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True
onnxruntime/transformers/fusion_qordered_gelu.py
@@ -0,0 +1,119 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger
from typing import Dict

from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionQOrderedGelu(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "QOrderedGelu", ["Gelu", "FastGelu"])

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        INPUT PATTERN
        Fuse (quantized) Gelu subgraph into one node QOrderedGelu:
            -> quantized input -> DQ -> Gelu -> Q ->

        (or)

            -> quantized input -> DQ -> FastGelu -> Q ->

        OUTPUT PATTERN
            -> QOrderedGelu ->
        """
        gelu_children = self.model.get_children(node, input_name_to_nodes)

        # Should only have 1 child - QuantizeLinear (or)
        # Should have 2 children - QuantizeLinear + Shape
        if not (
            (len(gelu_children) == 1 and gelu_children[0].op_type == "QuantizeLinear")
            or (
                len(gelu_children) == 2
                and gelu_children[0].op_type == "QuantizeLinear"
                and gelu_children[1].op_type == "Shape"
            )
        ):
            return

        downstream_quantize_node = gelu_children[0]
        downstream_shape_node = None

        if len(gelu_children) == 2:
            downstream_shape_node = gelu_children[1]

        if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
            return

        # The first input to Gelu should flow through a DequantizeLinear node
        first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
            node,
            [(["DequantizeLinear"], [0])],
            output_name_to_node,
        )

        if first_path_id < 0:
            return

        upstream_dequantize_node = first_input_parent_nodes[0]

        if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model):
            return

        # Fusion logic
        subgraph_nodes = [node]  # Gelu/FastGelu
        subgraph_nodes.extend([downstream_quantize_node, upstream_dequantize_node])  # Relevant Q, DQ nodes

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            (
                [node.output[0], downstream_quantize_node.output[0]]
                if downstream_shape_node is not None
                else downstream_quantize_node.output
            ),
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("It is not safe to fuse QOrderedGelu node. Skip")
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        ordered_gelu_node = helper.make_node(
            "QOrderedGelu",
            inputs=[
                upstream_dequantize_node.input[0],
                upstream_dequantize_node.input[1],
                downstream_quantize_node.input[1],
            ],
            outputs=[downstream_quantize_node.output[0]],
            name=self.model.create_node_name("QOrderedGelu", name_prefix="QOrderedGelu"),
        )

        # Arrange the downstream Shape's input to be fed from the
        # downstream QuantizeLinear node, so that fusion will
        # be deemed safe
        if downstream_shape_node is not None:
            self.model.replace_node_input(
                downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0]
            )

        # TODO: We only support CuBlasLt order ORDER_ROW for now.
        # Once we start supporting other data ordering format(s), we
        # will support user configuring the data ordering for the op.
        ordered_gelu_node.attribute.extend([helper.make_attribute("order_X", 1)])
        ordered_gelu_node.attribute.extend([helper.make_attribute("order_Y", 1)])

        ordered_gelu_node.domain = "com.microsoft"

        self.nodes_to_add.append(ordered_gelu_node)
        self.node_name_to_graph_name[ordered_gelu_node.name] = self.this_graph_name
onnxruntime/transformers/fusion_qordered_layernorm.py
@@ -0,0 +1,123 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict

from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionQOrderedLayerNormalization(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "QOrderedLayerNormalization", "LayerNormalization")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse (quantized) Layer Normalization subgraph into one node QOrderedLayerNormalization:
        quantized input -> DQ
                            |
                            |
        (other inputs)-> LayerNormalization --> Q -->

        should become

        (quantized input + other inputs)-> QOrderedLayerNormalization --> Q -->
        """

        children = self.model.get_children(node, input_name_to_nodes)

        # Should only have 1 child - QuantizeLinear (or)
        # Should have 2 children - QuantizeLinear + Shape
        if not (
            (len(children) == 1 and children[0].op_type == "QuantizeLinear")
            or (len(children) == 2 and children[0].op_type == "QuantizeLinear" and children[1].op_type == "Shape")
        ):
            return

        downstream_quantize_node = children[0]
        downstream_shape_node = None

        if len(children) == 2:
            downstream_shape_node = children[1]

        if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
            return

        # The first input to LayerNormalization should flow through a DequantizeLinear node
        first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
            node,
            [(["DequantizeLinear"], [0])],
            output_name_to_node,
        )

        if first_path_id < 0:
            return

        upstream_dequantize_node = first_input_parent_nodes[0]

        if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model):
            return

        # Fusion logic
        subgraph_nodes = [node]  # LayerNormalization
        subgraph_nodes.extend([downstream_quantize_node])  # Q node after LayerNormalization

        upstream_dequantize_node_children = self.model.get_children(upstream_dequantize_node, input_name_to_nodes)

        # In GPT2, the DQ node will be feeding a residual downstream Add and hence,
        # we do not want to remove it
        if len(upstream_dequantize_node_children) == 1:
            subgraph_nodes.extend([upstream_dequantize_node])  # DQ node before LayerNormalization

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            (
                [node.output[0], downstream_quantize_node.output[0]]
                if downstream_shape_node is not None
                else downstream_quantize_node.output
            ),
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("It is not safe to fuse QOrderedLayerNormalization node. Skip")
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        normalize_node = helper.make_node(
            "QOrderedLayerNormalization",
            inputs=[
                upstream_dequantize_node.input[0],
                upstream_dequantize_node.input[1],
                node.input[1],
                node.input[2],
                downstream_quantize_node.input[1],
            ],
            outputs=[downstream_quantize_node.output[0]],
            name=self.model.create_node_name("QOrderedLayerNormalization", name_prefix="QOrderedLayerNormalization"),
        )

        # Arrange the downstream Shape's input to be fed from the
        # downstream QuantizeLinear node, so that fusion will
        # be deemed safe
        if downstream_shape_node is not None:
            self.model.replace_node_input(
                downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0]
            )

        # TODO: We only support CuBlasLt order ORDER_ROW for now.
        # Once we start supporting other data ordering format(s), we
        # will support user configuring the data ordering for the op.
        normalize_node.attribute.extend([helper.make_attribute("order_X", 1)])
        normalize_node.attribute.extend([helper.make_attribute("order_Y", 1)])

        normalize_node.domain = "com.microsoft"

        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
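The Gelu and LayerNormalization passes above follow the same pattern and only need the OnnxModel wrapper. Continuing the hedged sketch from the attention example (same sys.path setup and model object, placeholder model):

from fusion_qordered_gelu import FusionQOrderedGelu
from fusion_qordered_layernorm import FusionQOrderedLayerNormalization

# FusionQOrderedAttention anchors on QOrderedLayerNormalization nodes,
# so the layer-norm pass is applied before the attention pass in practice.
FusionQOrderedLayerNormalization(model).apply()
FusionQOrderedGelu(model).apply()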