onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
from fusion_attention import AttentionMask, FusionAttention
|
|
8
|
+
from fusion_utils import NumpyHelper
|
|
9
|
+
from onnx import NodeProto, helper
|
|
10
|
+
from onnx_model import OnnxModel
|
|
11
|
+
from onnx_model_bert import BertOnnxModel
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FusionTnlrAttention(FusionAttention):
    """
    Fuse TNLR Attention subgraph into one Attention node.

    TNLR Attention has an extra addition after the qk nodes and adopts [S, B, NH]
    as I/O shape, so the fused Attention node is wired through Transpose nodes to
    restore the layout expected by the rest of the graph.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def create_attention_node(
        self,
        mask_index: str,
        matmul: NodeProto,
        add: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
        add_qk_str: str,
    ) -> NodeProto | None:
        """Create a com.microsoft Attention node from the packed QKV projection.

        Args:
            mask_index: name of the mask index tensor, or None when there is no mask.
            matmul: MatMul node whose second input is the packed QKV weight initializer.
            add: Add node holding the packed QKV bias initializer (either input slot).
            num_heads: number of attention heads; must be positive.
            hidden_size: hidden dimension; when positive it must be a multiple of num_heads.
            input: name of the attention input tensor.
            output: name of the attention output tensor.
            add_qk_str: name of the tensor added to QK before Softmax (relative
                position bias), or None.

        Returns:
            The new Attention node, or None when the required initializers are
            missing or the sizes are inconsistent.
        """
        assert num_heads > 0
        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        weight = self.model.get_initializer(matmul.input[1])
        bias = self.model.get_initializer(add.input[1]) or self.model.get_initializer(add.input[0])

        if weight is None or bias is None:
            return None

        qkv_weight = NumpyHelper.to_array(weight)
        qkv_bias = NumpyHelper.to_array(bias)

        attention_node_name = self.model.create_node_name("Attention")

        # Re-create the packed weight/bias under names derived from the new node so
        # the fused Attention does not depend on the original tensor names.
        tensor_dtype = weight.data_type
        np_type = helper.tensor_dtype_to_np_dtype(tensor_dtype)
        weight = helper.make_tensor(
            name=attention_node_name + "_qkv_weight",
            data_type=tensor_dtype,
            dims=[hidden_size, 3 * hidden_size],
            vals=qkv_weight.astype(np_type).tobytes(),
            raw=True,
        )
        self.model.add_initializer(weight, self.this_graph_name)

        bias = helper.make_tensor(
            name=attention_node_name + "_qkv_bias",
            data_type=tensor_dtype,
            dims=[3 * hidden_size],
            vals=qkv_bias.astype(np_type).tobytes(),
            raw=True,
        )
        self.model.add_initializer(bias, self.this_graph_name)

        attention_inputs = [
            input,
            attention_node_name + "_qkv_weight",
            attention_node_name + "_qkv_bias",
        ]
        if mask_index is not None:
            attention_inputs.append(mask_index)
        else:
            attention_inputs.append("")

        if add_qk_str is not None:
            # The empty string fills the optional "past" input slot; the relative
            # position bias goes in the slot after it.
            attention_inputs.append("")
            attention_inputs.append(add_qk_str)

        attention_node = helper.make_node(
            "Attention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Match the TNLR attention subgraph feeding *normalize_node* and fuse it."""
        # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm
        # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern
        start_node = normalize_node
        if normalize_node.op_type != "SkipLayerNormalization":
            return

        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Where", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (_, _, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
        else:
            return

        # The root input is the SkipLayerNormalization input that is produced by
        # some node but is not the output of the matched qkv path.
        other_inputs = []
        for input_name in start_node.input:
            if input_name not in output_name_to_node:
                continue
            if input_name == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(input_name)
        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if v_nodes is None:
            return
        # Q, K and V are sliced from a single packed Add+MatMul projection, so the
        # (add, matmul) pair found on any of the three paths is the same pair.
        (_, _, _, add, matmul) = v_nodes

        upper_nodes = self.model.match_parent_path(matmul, ["Transpose"], [0])
        if upper_nodes is None:
            # Fix: match_parent_path returns None when the pattern does not match;
            # indexing it unguarded raised TypeError on non-TNLR graphs.
            return
        transpose = upper_nodes[0]

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            return
        (_, add_qk, matmul_qk) = qk_nodes

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 1],
        )
        if q_nodes is None:
            return
        add = q_nodes[-2]
        matmul = q_nodes[-1]

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if k_nodes is None:
            return
        add = k_nodes[-2]
        matmul = k_nodes[-1]

        relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0])
        if relative_position_bias_nodes is None:
            return

        if matmul.input[0] == root_input:
            mask_index = None
            attention_last_node = reshape_qkv
            # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
            # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately
            new_node = self.create_attention_node(
                mask_index,
                matmul,
                add,
                self.num_heads,
                self.hidden_size,
                root_input,
                attention_last_node.output[0],
                relative_position_bias_nodes[0].input[0],
            )
            if new_node is None:
                return

            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

            # Add a transpose node after the attention node
            back_transpose = helper.make_node(
                "Transpose",
                ["back_transpose_in_" + new_node.name],
                [new_node.output[0]],
                "back_transpose_" + new_node.name,
                perm=[1, 0, 2],
            )
            self.model.add_node(back_transpose, self.this_graph_name)
            new_node.input[0] = transpose.input[0]
            new_node.output[0] = "back_transpose_in_" + new_node.name

            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class TnlrOnnxModel(BertOnnxModel):
    """BERT-style ONNX model specialized with the TNLR attention fusion pass."""

    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        mask = AttentionMask(self)
        self.attention_mask = mask
        self.attention_fusion = FusionTnlrAttention(self, self.hidden_size, self.num_heads, mask)

    def fuse_attention(self):
        """Run the TNLR attention fusion over the whole graph."""
        self.attention_fusion.apply()
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
from fusion_attention_unet import FusionAttentionUnet
|
|
9
|
+
from fusion_bias_add import FusionBiasAdd
|
|
10
|
+
from fusion_biassplitgelu import FusionBiasSplitGelu
|
|
11
|
+
from fusion_group_norm import FusionGroupNorm
|
|
12
|
+
from fusion_nhwc_conv import FusionNhwcConv
|
|
13
|
+
from fusion_options import FusionOptions
|
|
14
|
+
from fusion_skip_group_norm import FusionSkipGroupNorm
|
|
15
|
+
from fusion_transpose import FusionInsertTranspose, FusionTranspose
|
|
16
|
+
from import_utils import is_installed
|
|
17
|
+
from onnx import ModelProto
|
|
18
|
+
from onnx_model import OnnxModel
|
|
19
|
+
from onnx_model_bert import BertOnnxModel
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class UnetOnnxModel(BertOnnxModel):
    """Graph optimizer for UNet-style ONNX models (e.g. Stable Diffusion UNet).

    Extends BertOnnxModel with UNet-specific fusions: GroupNorm/SkipGroupNorm,
    BiasSplitGelu, NhwcConv conversion, packed QKV/KV multi-head attention,
    and Transpose cleanup.
    """

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize UNet ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
        """
        # Either both are auto-detected (0), or hidden_size must divide evenly among heads.
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def preprocess(self):
        # Hook called early in _optimize, before the fusion passes.
        self.remove_useless_div()

    def postprocess(self):
        # Hook called at the end of _optimize: drop dangling nodes/constants left by fusions.
        self.prune_graph()
        self.remove_unused_constant()

    def remove_useless_div(self):
        """Remove Div by 1"""
        div_nodes = [node for node in self.nodes() if node.op_type == "Div"]

        nodes_to_remove = []
        for div in div_nodes:
            # find_constant_input presumably returns the input index holding the constant;
            # index 1 means the divisor is the constant 1.0, making the Div a no-op.
            if self.find_constant_input(div, 1.0) == 1:
                nodes_to_remove.append(div)

        # Re-wire every consumer of a no-op Div to read the dividend directly.
        for node in nodes_to_remove:
            self.replace_input_of_all_nodes(node.output[0], node.input[0])

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            logger.info("Removed %d Div nodes", len(nodes_to_remove))

    def convert_conv_to_nhwc(self):
        """Convert Conv nodes to NhwcConv, transposing weights offline."""
        # Transpose weights in offline might help since ORT does not apply constant-folding on Transpose nodes.
        conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=True)
        conv_to_nhwc_conv.apply()

    def merge_adjacent_transpose(self):
        """Merge adjacent Transpose nodes, then delete identity Transpose nodes."""
        fusion_transpose = FusionTranspose(self)
        fusion_transpose.apply()

        remove_count = 0
        nodes = self.get_nodes_by_op_type("Transpose")
        for node in nodes:
            permutation = OnnxModel.get_node_attribute(node, "perm")
            assert isinstance(permutation, list)
            # Only identity permutations (perm == [0, 1, ..., n-1]) are removable.
            if permutation != list(range(len(permutation))):
                continue
            # Removal is only safe for nodes internal to the graph: the node must not
            # produce a graph output nor consume a graph input/output.
            assert not (
                self.find_graph_output(node.output[0])
                or self.find_graph_input(node.input[0])
                or self.find_graph_output(node.input[0])
            )

            # Let all children nodes skip current Transpose node and link to its parent
            # Note that we cannot update parent node output since parent node might have more than one children.
            self.replace_input_of_all_nodes(node.output[0], node.input[0])

            self.remove_node(node)
            remove_count += 1

        total = len(fusion_transpose.nodes_to_remove) + remove_count
        if total:
            logger.info("Removed %d Transpose nodes", total)

    def fuse_multi_head_attention(self, options: FusionOptions | None = None):
        """Fuse self- and cross-attention subgraphs into MultiHeadAttention nodes.

        Args:
            options (FusionOptions, optional): fusion options. None enables
                packed QKV (self attention) and packed KV (cross attention).
        """
        # Self Attention
        enable_packed_qkv = (options is None) or options.enable_packed_qkv
        self_attention_fusion = FusionAttentionUnet(
            self,
            self.hidden_size,
            self.num_heads,
            is_cross_attention=False,
            enable_packed_qkv=enable_packed_qkv,
            enable_packed_kv=False,
        )
        self_attention_fusion.apply()

        # Cross Attention
        enable_packed_kv = (options is None) or options.enable_packed_kv
        cross_attention_fusion = FusionAttentionUnet(
            self,
            self.hidden_size,
            self.num_heads,
            is_cross_attention=True,
            enable_packed_qkv=False,
            enable_packed_kv=enable_packed_kv,
        )
        cross_attention_fusion.apply()

    def fuse_bias_add(self):
        """Fuse Add-bias patterns into BiasAdd nodes."""
        fusion = FusionBiasAdd(self)
        fusion.apply()

    def optimize(self, options: FusionOptions | None = None):
        """Run the full fusion pipeline, with a tqdm progress bar when available."""
        if is_installed("tqdm"):
            import tqdm  # noqa: PLC0415
            from tqdm.contrib.logging import logging_redirect_tqdm  # noqa: PLC0415

            with logging_redirect_tqdm():
                # steps must match the number of progress_bar.update(1) calls in _optimize.
                steps = 18
                progress_bar = tqdm.tqdm(range(steps), initial=0, desc="fusion")
                self._optimize(options, progress_bar)
        else:
            logger.info("tqdm is not installed. Run optimization without progress bar")
            self._optimize(options, None)

    def _optimize(self, options: FusionOptions | None = None, progress_bar=None):
        """Fusion pipeline. Pass ordering matters; do not reorder steps.

        Args:
            options (FusionOptions, optional): per-fusion enable flags; None enables defaults.
            progress_bar: a tqdm-like object updated once per step, or None.
        """
        if (options is not None) and not options.enable_shape_inference:
            self.disable_shape_inference()

        self.utils.remove_identity_nodes()
        if progress_bar:
            progress_bar.update(1)

        # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()
        if progress_bar:
            progress_bar.update(1)

        self.preprocess()
        if progress_bar:
            progress_bar.update(1)

        self.fuse_reshape()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_group_norm:
            channels_last = (options is None) or options.group_norm_channels_last
            group_norm_fusion = FusionGroupNorm(self, channels_last)
            group_norm_fusion.apply()

            insert_transpose_fusion = FusionInsertTranspose(self)
            insert_transpose_fusion.apply()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_bias_splitgelu:
            bias_split_gelu_fusion = FusionBiasSplitGelu(self)
            bias_split_gelu_fusion.apply()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_attention:
            # self.save_model_to_file("before_mha.onnx")
            self.fuse_multi_head_attention(options)
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()
        if progress_bar:
            progress_bar.update(1)

        self.fuse_shape()
        if progress_bar:
            progress_bar.update(1)

        # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_skip_group_norm:
            skip_group_norm_fusion = FusionSkipGroupNorm(self)
            skip_group_norm_fusion.apply()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()
        if progress_bar:
            progress_bar.update(1)

        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()
        if progress_bar:
            progress_bar.update(1)

        if options is None or options.enable_nhwc_conv:
            self.convert_conv_to_nhwc()
            self.merge_adjacent_transpose()
        if progress_bar:
            progress_bar.update(1)

        if options is not None and options.enable_bias_add:
            self.fuse_bias_add()
        if progress_bar:
            progress_bar.update(1)

        self.postprocess()
        if progress_bar:
            progress_bar.update(1)

        logger.info(f"opset version: {self.get_opset_version()}")

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        op_count = {}
        ops = [
            "Attention",
            "MultiHeadAttention",
            "LayerNormalization",
            "SkipLayerNormalization",
            "BiasSplitGelu",
            "GroupNorm",
            "SkipGroupNorm",
            "NhwcConv",
            "BiasAdd",
        ]

        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)

        logger.info(f"Optimized operators:{op_count}")
        return op_count
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
|
|
6
|
+
from logging import getLogger
|
|
7
|
+
|
|
8
|
+
from fusion_attention_vae import FusionAttentionVae
|
|
9
|
+
from fusion_options import FusionOptions
|
|
10
|
+
from onnx import ModelProto
|
|
11
|
+
from onnx_model_unet import UnetOnnxModel
|
|
12
|
+
|
|
13
|
+
logger = getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class VaeOnnxModel(UnetOnnxModel):
    """Graph optimizer for VAE ONNX models.

    Reuses the UNet fusion pipeline, but fuses attention with
    FusionAttentionVae: a VAE has self attention only, no cross attention.
    """

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Create a VAE optimizer; 0 for num_heads/hidden_size means auto-detect."""
        params_ok = (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
        assert params_ok
        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def fuse_multi_head_attention(self, options: FusionOptions | None = None):
        """Fuse the VAE's self-attention subgraphs (options are not consulted here)."""
        FusionAttentionVae(self, self.hidden_size, self.num_heads).apply()

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        fused_op_types = (
            "Attention",
            "GroupNorm",
            "SkipGroupNorm",
            "NhwcConv",
        )
        op_count = {op_type: len(self.get_nodes_by_op_type(op_type)) for op_type in fused_op_types}

        logger.info(f"Optimized operators:{op_count}")
        return op_count
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from fusion_utils import NumpyHelper
|
|
6
|
+
from onnx import ModelProto, TensorProto
|
|
7
|
+
from onnx.external_data_helper import set_external_data
|
|
8
|
+
from onnx_model import OnnxModel
|
|
9
|
+
|
|
10
|
+
from onnxruntime import OrtValue
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def extract_raw_data_from_model(model: ModelProto):
    """
    Extract external data from model and return the external data as a list of tuples (name, value).
    Note this function does not handle external data that is not loaded into the model as raw data.

    Args:
        model (ModelProto): the model proto to extract external data from.
    Returns:
        (external_names, external_values): a tuple of two sequences of external data names and values.
            Both are empty when the model holds no raw-data initializers.
    """
    external_data = []
    onnx_model = OnnxModel(model)
    for graph in onnx_model.graphs():
        for initializer in graph.initializer:
            name = initializer.name

            if initializer.HasField("raw_data"):
                # Copy the tensor payload into an OrtValue before stripping it from the proto.
                numpy_tensor = NumpyHelper.to_array(initializer)
                ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
                external_data.append((name, ort_value))
                # mimic set_external_data
                set_external_data(initializer, location="foo.bin")
                initializer.name = name
                initializer.ClearField("raw_data")

    # Bug fix: with no raw-data initializers, zip(*[]) is zip() which yields nothing,
    # so callers unpacking "(names, values) = ..." got "not enough values to unpack".
    # Return an explicit pair of empty sequences instead.
    if not external_data:
        return (), ()
    return zip(*external_data, strict=False)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def has_external_data(model: ModelProto):
    """
    Check if the model has external data.

    Args:
        model (ModelProto): the model proto to check for external data.
    Returns:
        bool: True if the model has external data, False otherwise.
    """
    # True as soon as any initializer in any (sub)graph is marked EXTERNAL.
    return any(
        initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL
        for graph in OnnxModel(model).graphs()
        for initializer in graph.initializer
    )
|