onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
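
Note: the wheel bundles DirectML.dll and registers the DmlExecutionProvider alongside the standard ONNX Runtime Python API. For orientation, here is a minimal usage sketch (not taken from the package itself; the model path and input shape are placeholders):

    import numpy as np
    import onnxruntime as ort

    # Prefer DirectML, fall back to CPU if the device has no DML support.
    session = ort.InferenceSession(
        "model.onnx",  # placeholder path
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: np.zeros((1, 3, 224, 224), dtype=np.float32)})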
onnxruntime/transformers/convert_tf_models_to_pytorch.py
@@ -0,0 +1,205 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import glob
+import os
+
+import requests
+
+TFMODELS = {
+    "bert-base-uncased": (
+        "bert",
+        "BertConfig",
+        "",
+        "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip",
+    ),
+    "bert-base-cased": (
+        "bert",
+        "BertConfig",
+        "",
+        "https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip",
+    ),
+    "bert-large-uncased": (
+        "bert",
+        "BertConfig",
+        "",
+        "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip",
+    ),
+    "albert-base": (
+        "albert",
+        "AlbertConfig",
+        "",
+        "https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz",
+    ),
+    "albert-large": (
+        "albert",
+        "AlbertConfig",
+        "",
+        "https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz",
+    ),
+    "gpt-2-117M": (
+        "gpt2",
+        "GPT2Config",
+        "GPT2Model",
+        "https://storage.googleapis.com/gpt-2/models/117M",
+    ),
+    "gpt-2-124M": (
+        "gpt2",
+        "GPT2Config",
+        "GPT2Model",
+        "https://storage.googleapis.com/gpt-2/models/124M",
+    ),
+}
+
+
+def download_compressed_file(tf_ckpt_url, ckpt_dir):
+    r = requests.get(tf_ckpt_url)
+    compressed_file_name = tf_ckpt_url.split("/")[-1]
+    compressed_file_dir = os.path.join(ckpt_dir, compressed_file_name)
+    with open(compressed_file_dir, "wb") as f:
+        f.write(r.content)
+    return compressed_file_dir
+
+
+def get_ckpt_prefix_path(ckpt_dir):
+    # get prefix
+    sub_folder_dir = None
+    for o in os.listdir(ckpt_dir):
+        sub_folder_dir = os.path.join(ckpt_dir, o)
+        break
+    if os.path.isfile(sub_folder_dir):
+        sub_folder_dir = ckpt_dir
+    unique_file_name = str(glob.glob(sub_folder_dir + "/*data-00000-of-00001"))
+    prefix = (unique_file_name.rpartition(".")[0]).split("/")[-1]
+
+    return os.path.join(sub_folder_dir, prefix)
+
+
+def download_tf_checkpoint(model_name, tf_models_dir="tf_models"):
+    import pathlib
+
+    base_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), tf_models_dir)
+    ckpt_dir = os.path.join(base_dir, model_name)
+
+    if not os.path.exists(ckpt_dir):
+        os.makedirs(ckpt_dir)
+
+    tf_ckpt_url = TFMODELS[model_name][3]
+
+    import re
+
+    if re.search(r"\.zip$", tf_ckpt_url) is not None:
+        zip_dir = download_compressed_file(tf_ckpt_url, ckpt_dir)
+
+        # unzip file
+        import zipfile
+
+        with zipfile.ZipFile(zip_dir, "r") as zip_ref:
+            zip_ref.extractall(ckpt_dir)
+        os.remove(zip_dir)
+
+        return get_ckpt_prefix_path(ckpt_dir)
+
+    elif re.search(r"\.tar\.gz$", tf_ckpt_url) is not None:
+        tar_dir = download_compressed_file(tf_ckpt_url, ckpt_dir)
+
+        # untar file
+        import tarfile
+
+        with tarfile.open(tar_dir, "r") as tar_ref:
+            tar_ref.extractall(ckpt_dir)
+        os.remove(tar_dir)
+
+        return get_ckpt_prefix_path(ckpt_dir)
+
+    else:
+        for filename in [
+            "checkpoint",
+            "model.ckpt.data-00000-of-00001",
+            "model.ckpt.index",
+            "model.ckpt.meta",
+        ]:
+            r = requests.get(tf_ckpt_url + "/" + filename)
+            with open(os.path.join(ckpt_dir, filename), "wb") as f:
+                f.write(r.content)
+
+        return get_ckpt_prefix_path(ckpt_dir)
+
+
+def init_pytorch_model(model_name, tf_checkpoint_path):
+    config_name = TFMODELS[model_name][1]
+    config_module = __import__("transformers", fromlist=[config_name])
+    model_config = getattr(config_module, config_name)
+
+    parent_path = tf_checkpoint_path.rpartition("/")[0]
+    config_path = glob.glob(parent_path + "/*config.json")
+    config = model_config() if len(config_path) == 0 else model_config.from_json_file(str(config_path[0]))
+
+    if not TFMODELS[model_name][2]:
+        from transformers import AutoModelForPreTraining
+
+        init_model = AutoModelForPreTraining.from_config(config)
+    else:
+        model_category_name = TFMODELS[model_name][2]
+        module = __import__("transformers", fromlist=[model_category_name])
+        model_category = getattr(module, model_category_name)
+        init_model = model_category(config)
+    return config, init_model
+
+
+def convert_tf_checkpoint_to_pytorch(model_name, config, init_model, tf_checkpoint_path, is_tf2):
+    load_tf_weight_func_name = "load_tf_weights_in_" + TFMODELS[model_name][0]
+
+    module = __import__("transformers", fromlist=[load_tf_weight_func_name])
+
+    if is_tf2 is False:
+        load_tf_weight_func = getattr(module, load_tf_weight_func_name)
+    else:
+        if TFMODELS[model_name][0] != "bert":
+            raise NotImplementedError("Only support TF2 checkpoint for BERT model")
+        from transformers import convert_bert_original_tf2_checkpoint_to_pytorch
+
+        load_tf_weight_func = convert_bert_original_tf2_checkpoint_to_pytorch.load_tf2_weights_in_bert
+
+    # Expect transformers team will unify the order of signature in the future
+    model = (
+        load_tf_weight_func(init_model, config, tf_checkpoint_path)
+        if is_tf2 is False
+        else load_tf_weight_func(init_model, tf_checkpoint_path, config)
+    )
+    model.eval()
+    return model
+
+
+def tf2pt_pipeline(model_name, is_tf2=False):
+    if model_name not in TFMODELS:
+        raise NotImplementedError(model_name + " not implemented")
+    tf_checkpoint_path = download_tf_checkpoint(model_name)
+    config, init_model = init_pytorch_model(model_name, tf_checkpoint_path)
+    model = convert_tf_checkpoint_to_pytorch(model_name, config, init_model, tf_checkpoint_path, is_tf2)
+    # Could then use the model in Benchmark
+    return config, model
+
+
+def tf2pt_pipeline_test():
+    # For test on linux only
+    import logging
+
+    import torch
+
+    logger = logging.getLogger("")
+    for model_name in TFMODELS:
+        config, model = tf2pt_pipeline(model_name)
+        assert config.model_type == TFMODELS[model_name][0]
+
+        input_ids = torch.randint(low=0, high=config.vocab_size - 1, size=(4, 128), dtype=torch.long)
+        try:
+            model(input_ids)
+        except RuntimeError as e:
+            logger.exception(e)
+
+
+if __name__ == "__main__":
+    tf2pt_pipeline_test()
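
The file above is driven end to end by tf2pt_pipeline. A minimal usage sketch, assuming torch and transformers are installed and the Google Storage checkpoint URLs are still reachable:

    from onnxruntime.transformers.convert_tf_models_to_pytorch import tf2pt_pipeline

    # Downloads the TF checkpoint, builds a Hugging Face model from its config,
    # and loads the TF weights into it (is_tf2=False selects the TF1 loader).
    config, model = tf2pt_pipeline("bert-base-uncased", is_tf2=False)
    print(config.model_type)  # "bert"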
onnxruntime/transformers/convert_to_packing_mode.py
@@ -0,0 +1,387 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import argparse
+import logging
+import os
+from typing import List, Union
+
+import coloredlogs
+from constants import (
+    AttentionInputIDs,
+    AttentionOutputIDs,
+    MultiHeadAttentionInputIDs,
+    MultiHeadAttentionOutputIDs,
+    Operators,
+)
+from onnx import helper, load_model
+from onnx_model import NodeProto, OnnxModel
+from shape_infer_helper import SymbolicShapeInferenceHelper
+
+logger = logging.getLogger(__name__)
+
+
+class PackingAttentionBase:
+    def __init__(self, model: OnnxModel, attention_op_type: str):
+        self.model: OnnxModel = model
+        self.nodes_to_remove: List = []
+        self.nodes_to_add: List = []
+        self.prune_graph: bool = False
+        self.node_name_to_graph_name: dict = {}
+        self.this_graph_name: str = self.model.model.graph.name
+        self.attention_op_type = attention_op_type
+        self.attention_nodes = self.model.get_nodes_by_op_type(attention_op_type)
+
+    def _try_getting_attention_mask(self) -> Union[str, None]:
+        mask_index = (
+            AttentionInputIDs.MASK_INDEX
+            if self.attention_op_type == Operators.ATTENTION
+            else MultiHeadAttentionInputIDs.KEY_PADDING_MASK
+        )
+        first_attention_node = self._try_getting_first_attention()
+        # check if attention has mask
+        if not first_attention_node or len(first_attention_node.input) <= mask_index:
+            return None
+
+        attention_mask = first_attention_node.input[mask_index]
+
+        # check if all attention nodes have same mask
+        for node in self.attention_nodes:
+            if len(node.input) <= mask_index or node.input[mask_index] != attention_mask:
+                return None
+
+        return attention_mask
+
+    def _try_getting_first_attention(self) -> Union[NodeProto, None]:
+        if len(self.attention_nodes) <= 0:
+            return None
+
+        return self.attention_nodes[0]
+
+    def _try_getting_last_layernorm(self) -> Union[NodeProto, None]:
+        last_layernorm_node = None
+        for node in self.model.nodes():
+            if node.op_type == Operators.LAYERNORM or node.op_type == Operators.SKIPLAYERNORM:
+                last_layernorm_node = node
+        return last_layernorm_node
+
+    def _are_attentions_supported(self) -> bool:
+        raise NotImplementedError()
+
+    def _insert_removepadding_node(self, inputs: List[str], outputs: List[str]) -> None:
+        new_node = helper.make_node(
+            Operators.REMOVEPADDING,
+            inputs=inputs,
+            outputs=outputs,
+            name=self.model.create_node_name(Operators.REMOVEPADDING),
+        )
+
+        new_node.domain = "com.microsoft"
+        self.nodes_to_add.append(new_node)
+        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+    def _insert_restorepadding_node(self, inputs: List[str], outputs: List[str]) -> None:
+        new_node = helper.make_node(
+            Operators.RESTOREPADDING,
+            inputs=inputs,
+            outputs=outputs,
+            name=self.model.create_node_name(Operators.RESTOREPADDING),
+        )
+
+        new_node.domain = "com.microsoft"
+        self.nodes_to_add.append(new_node)
+        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+    def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
+        raise NotImplementedError()
+
+    def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None]:
+        if self.attention_op_type == Operators.ATTENTION:
+            return first_attention_node.input[AttentionInputIDs.INPUT]
+        return None
+
+    def convert(self, use_symbolic_shape_infer: bool = True) -> None:
+        logger.debug("start converting to packing model...")
+
+        if not self._are_attentions_supported():
+            return
+
+        attention_mask = self._try_getting_attention_mask()
+        if not attention_mask:
+            return
+
+        first_attention_node = self._try_getting_first_attention()
+        last_layernorm_node = self._try_getting_last_layernorm()
+        if not last_layernorm_node:
+            return
+
+        # insert RemovePadding
+        input_to_remove_padding = self._get_input_to_remove_padding(first_attention_node)
+        if not input_to_remove_padding:
+            return
+
+        output_without_padding = input_to_remove_padding + "_no_padding"
+        token_offset = input_to_remove_padding + "_token_offset"
+        cumulated_seq_len = input_to_remove_padding + "_cumulated_seq_len"
+        max_seq_len = input_to_remove_padding + "_max_seq_len"
+        self._insert_removepadding_node(
+            [input_to_remove_padding, attention_mask],
+            [output_without_padding, token_offset, cumulated_seq_len, max_seq_len],
+        )
+        self.model.replace_input_of_all_nodes(input_to_remove_padding, output_without_padding)
+        logger.debug("inserted RemovePadding before Attention")
+
+        # insert RestorePadding
+        restorepadding_input = last_layernorm_node.output[0] + "_restore_input"
+        self._insert_restorepadding_node([restorepadding_input, token_offset], [last_layernorm_node.output[0]])
+        self.model.replace_output_of_all_nodes(last_layernorm_node.output[0], restorepadding_input)
+        logger.debug(f"inserted RestorePadding after last {last_layernorm_node.op_type} layer")
+
+        # insert PackedAttention
+        self._replace_attention_with_packing_attention(token_offset, cumulated_seq_len)
+        logger.debug(f"replaced {self.attention_op_type} with Packed{self.attention_op_type}")
+
+        self.model.remove_nodes(self.nodes_to_remove)
+        self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
+
+        if self.prune_graph:
+            self.model.prune_graph()
+        elif self.nodes_to_remove or self.nodes_to_add:
+            self.model.update_graph()
+
+        self.model.clean_shape_infer()
+        if use_symbolic_shape_infer:
+            # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc)
+            # are not recognized by onnx shape inference.
+            shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model, verbose=0)
+            inferred_model = shape_infer_helper.infer_shapes(self.model.model, auto_merge=True, guess_output_rank=False)
+            if inferred_model:
+                self.model.model = inferred_model
+
+
+class PackingAttention(PackingAttentionBase):
+    def __init__(self, model: OnnxModel):
+        super().__init__(model, Operators.ATTENTION)
+
+    def _are_attentions_supported(self) -> bool:
+        for node in self.attention_nodes:
+            if OnnxModel.get_node_attribute(node, "past_present_share_buffer") is not None:
+                return False
+            if OnnxModel.get_node_attribute(node, "do_rotary") is not None:
+                return False
+            unidirection_attr = OnnxModel.get_node_attribute(node, "unidirectional")
+            if unidirection_attr is not None and unidirection_attr != 0:
+                return False
+            if len(node.input) > AttentionInputIDs.PAST and not node.input[AttentionInputIDs.PAST]:
+                return False
+            if (
+                len(node.input) > AttentionInputIDs.PAST_SEQUENCE_LENGTH
+                and not node.input[AttentionInputIDs.PAST_SEQUENCE_LENGTH]
+            ):
+                return False
+        return True
+
+    def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
+        for attention in self.attention_nodes:
+            attention_bias = (
+                attention.input[AttentionInputIDs.ATTENTION_BIAS]
+                if len(attention.input) > AttentionInputIDs.ATTENTION_BIAS
+                else ""
+            )
+            packed_attention = helper.make_node(
+                Operators.PACKEDATTENTION,
+                inputs=[
+                    attention.input[AttentionInputIDs.INPUT],
+                    attention.input[AttentionInputIDs.WEIGHTS],
+                    attention.input[AttentionInputIDs.BIAS],
+                    token_offset,
+                    cumulative_sequence_length,
+                    attention_bias,
+                ],
+                outputs=[attention.output[AttentionOutputIDs.OUTPUT]],
+                name=self.model.create_node_name(Operators.PACKEDATTENTION),
+            )
+
+            attributes = []
+            for attr in attention.attribute:
+                if attr.name in ["num_heads", "qkv_hidden_sizes", "scale"]:
+                    attributes.append(attr)
+
+            packed_attention.attribute.extend(attributes)
+            packed_attention.domain = "com.microsoft"
+            self.nodes_to_add.append(packed_attention)
+            self.nodes_to_remove.append(attention)
+            self.node_name_to_graph_name[packed_attention.name] = self.this_graph_name
+
+        logger.info("Converted %d Attention nodes to PackedAttention.", len(self.attention_nodes))
+
+
+class PackingMultiHeadAttention(PackingAttentionBase):
+    def __init__(self, model: OnnxModel):
+        super().__init__(model, Operators.MULTI_HEAD_ATTENTION)
+
+    def _check_empty_input(self, node, index: int, name: str):
+        """Check a node does not have given input."""
+        if len(node.input) > index:
+            if len(node.input[index]) > 0:
+                logger.error(f"node input {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
+                return False
+        return True
+
+    def _check_empty_output(self, node, index: int, name: str):
+        """Check a node does not have given output."""
+        if len(node.output) > index:
+            if len(node.output[index]) > 0:
+                logger.error(f"node output {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
+                return False
+        return True
+
+    def _are_attentions_supported(self) -> bool:
+        for node in self.attention_nodes:
+            for attr in node.attribute:
+                if attr.name not in ["num_heads", "mask_filter_value", "scale"]:
+                    logger.error(f"node attribute {attr.name} is not supported in PackedMultiHeadAttention: {node}")
+                    return False
+
+            if node.input[MultiHeadAttentionInputIDs.KEY] and not node.input[MultiHeadAttentionInputIDs.VALUE]:
+                logger.error("packed kv format is not supported in PackedMultiHeadAttention")
+                return False
+
+            if not (
+                self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_KEY, "past_key")
+                and self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_VALUE, "past_value")
+                and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_KEY, "present_key")
+                and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_VALUE, "present_value")
+            ):
+                return False
+
+        return True
+
+    def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
+        gated_relative_pos_bias_count = 0
+        for mha in self.attention_nodes:
+            attention_bias = (
+                mha.input[MultiHeadAttentionInputIDs.ATTENTION_BIAS]
+                if len(mha.input) > MultiHeadAttentionInputIDs.ATTENTION_BIAS
+                else ""
+            )
+            packed_mha = helper.make_node(
+                Operators.PACKED_MULTI_HEAD_ATTENTION,
+                inputs=[
+                    mha.input[MultiHeadAttentionInputIDs.QUERY],
+                    mha.input[MultiHeadAttentionInputIDs.KEY],
+                    mha.input[MultiHeadAttentionInputIDs.VALUE],
+                    mha.input[MultiHeadAttentionInputIDs.BIAS],
+                    token_offset,
+                    cumulative_sequence_length,
+                    attention_bias,
+                ],
+                outputs=[mha.output[MultiHeadAttentionOutputIDs.OUTPUT]],
+                name=self.model.create_node_name(Operators.PACKED_MULTI_HEAD_ATTENTION),
+            )
+
+            attributes = []
+            for attr in mha.attribute:
+                if attr.name in ["num_heads", "mask_filter_value", "scale"]:
+                    attributes.append(attr)
+
+            packed_mha.attribute.extend(attributes)
+            packed_mha.domain = "com.microsoft"
+            self.nodes_to_add.append(packed_mha)
+            self.nodes_to_remove.append(mha)
+            self.node_name_to_graph_name[packed_mha.name] = self.this_graph_name
+
+            # Append token_offset input to GatedRelativePositionBias
+            if attention_bias:
+                rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.ATTENTION_BIAS)
+                if (
+                    rel_pos_bias_node
+                    and rel_pos_bias_node.op_type == "GatedRelativePositionBias"
+                    and len(rel_pos_bias_node.input) == 6
+                ):
+                    rel_pos_bias_node.input.append(token_offset)
+                    gated_relative_pos_bias_count += 1
+
+        logger.info("Converted %d MultiHeadAttention nodes to PackedMultiHeadAttention.", len(self.attention_nodes))
+        logger.info("Converted %d GatedRelativePositionBias nodes to packing mode.", gated_relative_pos_bias_count)
+
+    def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None]:
+        # When there are query, key and value inputs, we need to find the first input of the parent MatMul node.
+        matmul = self.model.get_parent(first_attention_node, 0)
+        if matmul and matmul.op_type == "MatMul":
+            return matmul.input[0]
+        return None
+
+
+class PackingMode:
+    def __init__(self, model: OnnxModel):
+        self.model = model
+
+    def convert(self, use_symbolic_shape_infer: bool = True) -> None:
+        if self.model.get_nodes_by_op_type(Operators.ATTENTION):
+            if self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION):
+                logger.error("Packing mode does not support both Attention and MultiHeadAttention in same graph.")
+                return None
+            packing = PackingAttention(self.model)
+            return packing.convert(use_symbolic_shape_infer)
+        elif self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION):
+            packing = PackingMultiHeadAttention(self.model)
+            return packing.convert(use_symbolic_shape_infer)
+        else:
+            logger.error("Packing mode requires either Attention or MultiHeadAttention node in onnx graph.")
+            return None
+
+
+def _parse_arguments():
+    parser = argparse.ArgumentParser(
+        description="Convert to packing mode tool for ONNX Runtime. It converts a BERT-like model to use packing mode."
+    )
+    parser.add_argument("--input", required=True, type=str, help="input onnx model path")
+
+    parser.add_argument("--output", required=True, type=str, help="optimized onnx model path")
+
+    parser.add_argument("--verbose", required=False, action="store_true", help="show debug information.")
+    parser.set_defaults(verbose=False)
+
+    parser.add_argument(
+        "--use_external_data_format",
+        required=False,
+        action="store_true",
+        help="use external data format to store large model (>2GB)",
+    )
+    parser.set_defaults(use_external_data_format=False)
+
+    args = parser.parse_args()
+
+    return args
+
+
+def _setup_logger(verbose):
+    if verbose:
+        coloredlogs.install(
+            level="DEBUG",
+            fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
+        )
+    else:
+        coloredlogs.install(fmt="%(funcName)20s: %(message)s")
+
+
+def main():
+    args = _parse_arguments()
+
+    _setup_logger(args.verbose)
+
+    logger.debug(f"arguments:{args}")
+
+    if os.path.realpath(args.input) == os.path.realpath(args.output):
+        logger.warning("Specified the same input and output path. Note that this may overwrite the original model")
+
+    model = load_model(args.input)
+    packing_mode = PackingMode(OnnxModel(model))
+    packing_mode.convert()
+    packing_mode.model.save_model_to_file(args.output, use_external_data_format=args.use_external_data_format)
+
+
+if __name__ == "__main__":
+    main()
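
The converter above can also be used programmatically, mirroring its main(). A hedged sketch, assuming the onnxruntime.transformers package puts its script directory on sys.path so the bare constants/onnx_model imports resolve; the model paths are placeholders:

    import onnx
    from onnxruntime.transformers.convert_to_packing_mode import PackingMode
    from onnxruntime.transformers.onnx_model import OnnxModel

    model = onnx.load("bert_optimized.onnx")  # placeholder path
    packing = PackingMode(OnnxModel(model))
    packing.convert()  # rewrites Attention/MultiHeadAttention to packed variants
    packing.model.save_model_to_file("bert_packed.onnx")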
onnxruntime/transformers/dynamo_onnx_helper.py
@@ -0,0 +1,104 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+
+import onnx
+
+
+class DynamoOnnxHelper:
+    """
+    Helper class for processing ONNX models exported by torch Dynamo.
+    """
+
+    def __init__(self, model: onnx.ModelProto):
+        self.model = model
+
+    def update_edges(self, edge_mapping: dict) -> None:
+        """
+        Updates the edges in the model according to the given mapping.
+        """
+        for node in self.model.graph.node:
+            for i in range(len(node.input)):
+                if node.input[i] in edge_mapping:
+                    node.input[i] = edge_mapping[node.input[i]]
+            for i in range(len(node.output)):
+                if node.output[i] in edge_mapping:
+                    node.output[i] = edge_mapping[node.output[i]]
+
+        for graph_input in self.model.graph.input:
+            if graph_input.name in edge_mapping:
+                graph_input.name = edge_mapping[graph_input.name]
+        for graph_output in self.model.graph.output:
+            if graph_output.name in edge_mapping:
+                graph_output.name = edge_mapping[graph_output.name]
+
+    def unroll_function(self, func_name: str) -> None:
+        """
+        Unrolls the function with the given name in the model.
+        """
+        logging.info(f"Unrolling function {func_name}...")
+        nodes_to_remove = []
+        nodes_to_add = []
+        edges_to_remove = []
+        edges_to_add = []
+        for node in self.model.graph.node:
+            if node.op_type == func_name:
+                nodes_to_remove.append(node)
+                edges_to_remove.extend(list(node.input) + list(node.output))
+
+        func_to_remove = None
+        for f in self.model.functions:
+            if f.name == func_name:
+                nodes_to_add.extend(list(f.node))
+                edges_to_add.extend(list(f.input) + list(f.output))
+                func_to_remove = f
+
+        assert len(edges_to_remove) == len(edges_to_add)
+
+        for node in nodes_to_remove:
+            self.model.graph.node.remove(node)
+        for node in nodes_to_add:
+            self.model.graph.node.append(node)
+        if func_to_remove is not None:
+            self.model.functions.remove(func_to_remove)
+
+        edge_mapping = {}
+        for i in range(len(edges_to_remove)):
+            k = edges_to_remove[i]
+            v = edges_to_add[i]
+            if k != v:
+                edge_mapping[k] = v
+
+        return self.update_edges(edge_mapping)
+
+    def remove_function(self, func_name: str, input_id: int, output_id: int) -> None:
+        """
+        Removes the function in the model.
+        """
+        edge_mapping = {}
+        nodes_to_remove = []
+        for node in self.model.graph.node:
+            if node.op_type.find(func_name) != -1:
+                edge_mapping[node.input[input_id]] = node.output[output_id]
+                nodes_to_remove.append(node)
+        for node in nodes_to_remove:
+            self.model.graph.node.remove(node)
+
+        self.update_edges(edge_mapping)
+
+    def remove_dropout_layer(self) -> None:
+        """
+        Removes the dropout layer in the model.
+        """
+        logging.info("Removing dropout layer...")
+        self.remove_function("Dropout", 0, 0)
+
+    def remove_lm_head_layer(self) -> None:
+        """
+        Removes the LM head layer in the model.
+        """
+        logging.info("Removing LM head layer...")
+        # bugbug: need to copy the right vi over
+        self.remove_function("Linear_lm_head", 2, 0)
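
DynamoOnnxHelper above is a thin post-processing wrapper over onnx.ModelProto. A usage sketch (the model paths and the function name passed to unroll_function are hypothetical examples, not values from this package):

    import onnx
    from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper

    model = onnx.load("exported_by_dynamo.onnx")  # placeholder path
    helper = DynamoOnnxHelper(model)
    helper.remove_dropout_layer()      # drops Dropout calls, rewiring input 0 to output 0
    helper.unroll_function("MyBlock")  # hypothetical local-function name to inline
    onnx.save(helper.model, "cleaned.onnx")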