onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
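For context, this wheel bundles the native runtime (onnxruntime.dll) together with DirectML.dll, so the DirectML execution provider is available out of the box. A minimal smoke test of an installed copy might look like the sketch below; it is illustrative only, using the tiny sigmoid.onnx sample shipped under onnxruntime/datasets (the (3, 4, 5) float32 input matches that sample model, and the input name is read from the session rather than assumed):

    import numpy as np
    import onnxruntime
    from onnxruntime.datasets import get_example

    print(onnxruntime.__version__)                # expected: 1.24.1
    print(onnxruntime.get_available_providers())  # should include "DmlExecutionProvider"

    # Run the bundled sample model on DirectML, falling back to CPU if needed.
    session = onnxruntime.InferenceSession(
        get_example("sigmoid.onnx"),
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )
    x = np.random.rand(3, 4, 5).astype(np.float32)
    (output,) = session.run(None, {session.get_inputs()[0].name: x})
    print(output.shape)  # (3, 4, 5)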
onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py
@@ -0,0 +1,108 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+
+from diffusion_models import PipelineInfo
+from engine_builder import EngineBuilder, EngineType
+
+logger = logging.getLogger(__name__)
+
+
+class TorchEngineBuilder(EngineBuilder):
+    def __init__(
+        self,
+        pipeline_info: PipelineInfo,
+        max_batch_size=16,
+        device="cuda",
+        use_cuda_graph=False,
+    ):
+        """
+        Initializes the PyTorch Engine Builder.
+
+        Args:
+            pipeline_info (PipelineInfo):
+                Version and type of pipeline.
+            max_batch_size (int):
+                Maximum batch size for dynamic batch engine.
+            device (str):
+                Device to run on.
+            use_cuda_graph (bool):
+                Use CUDA graph to capture engine execution and then launch inference.
+        """
+        super().__init__(
+            EngineType.TORCH,
+            pipeline_info,
+            max_batch_size=max_batch_size,
+            device=device,
+            use_cuda_graph=use_cuda_graph,
+        )
+
+        self.compile_config = {}
+        if use_cuda_graph:
+            self.compile_config = {
+                "clip": {"mode": "reduce-overhead", "dynamic": False},
+                "clip2": {"mode": "reduce-overhead", "dynamic": False},
+                "unet": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
+                "unetxl": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
+                "vae": {"mode": "reduce-overhead", "fullgraph": False, "dynamic": False},
+            }
+
+    def build_engines(
+        self,
+        framework_model_dir: str,
+    ):
+        import torch  # noqa: PLC0415
+
+        self.torch_device = torch.device("cuda", torch.cuda.current_device())
+        self.load_models(framework_model_dir)
+
+        pipe = self.load_pipeline_with_lora() if self.pipeline_info.lora_weights else None
+
+        built_engines = {}
+        for model_name, model_obj in self.models.items():
+            model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
+            if self.pipeline_info.is_xl() and not self.custom_fp16_vae:
+                model = model.to(device=self.torch_device, dtype=torch.float32)
+            else:
+                model = model.to(device=self.torch_device, dtype=torch.float16)
+
+            if model_name in self.compile_config:
+                compile_config = self.compile_config[model_name]
+                if model_name in ["unet", "unetxl"]:
+                    model.to(memory_format=torch.channels_last)
+                engine = torch.compile(model, **compile_config)
+                built_engines[model_name] = engine
+            else:  # eager mode
+                built_engines[model_name] = model
+
+        self.engines = built_engines
+
+    def run_engine(self, model_name, feed_dict):
+        if model_name in ["unet", "unetxl"]:
+            if "controlnet_images" in feed_dict:
+                return {"latent": self.engines[model_name](**feed_dict)}
+
+            if model_name == "unetxl":
+                added_cond_kwargs = {k: feed_dict[k] for k in feed_dict if k in ["text_embeds", "time_ids"]}
+                return {
+                    "latent": self.engines[model_name](
+                        feed_dict["sample"],
+                        feed_dict["timestep"],
+                        feed_dict["encoder_hidden_states"],
+                        added_cond_kwargs=added_cond_kwargs,
+                        return_dict=False,
+                    )[0]
+                }
+
+            return {
+                "latent": self.engines[model_name](
+                    feed_dict["sample"], feed_dict["timestep"], feed_dict["encoder_hidden_states"], return_dict=False
+                )[0]
+            }
+
+        if model_name in ["vae_encoder"]:
+            return {"latent": self.engines[model_name](feed_dict["images"])}
+
+        raise RuntimeError(f"Shall not reach here: {model_name}")
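The compile_config table above maps each submodel to torch.compile options: "reduce-overhead" is the PyTorch 2.x mode that enables CUDA graph capture, while fullgraph and dynamic control whole-graph compilation and static shapes. A self-contained sketch of the same pattern on a toy module (illustrative only, not part of the package; requires PyTorch 2.x and a CUDA device):

    import torch


    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(64, 64)

        def forward(self, x):
            return torch.nn.functional.silu(self.proj(x))


    model = Toy().to(device="cuda", dtype=torch.float16)

    # Same options the builder uses for unet/unetxl: CUDA-graph friendly
    # "reduce-overhead" mode, whole-graph compilation, static shapes.
    engine = torch.compile(model, mode="reduce-overhead", fullgraph=True, dynamic=False)

    x = torch.randn(2, 64, device="cuda", dtype=torch.float16)
    for _ in range(3):  # a few warm-up calls let the CUDA graph get captured
        y = engine(x)
    torch.cuda.synchronize()
    print(y.shape)  # torch.Size([2, 64])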
onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py
@@ -0,0 +1,590 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+#
+# This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference.
+#
+# Before running this script, follow README.md to set up the python environment and convert the stable diffusion
+# checkpoint to float32 onnx models.
+#
+# For example, if the float32 ONNX pipeline is saved to the ./sd-v1-5 directory, you can optimize and convert it
+# to float16 like the following:
+#    python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16
+#
+# Note that the optimizations are carried out for the CUDA Execution Provider first; other EPs may not support the
+# fused operators. Users can disable operator fusion manually to work around this.
+
+import argparse
+import logging
+import os
+import shutil
+import tempfile
+import warnings
+from pathlib import Path
+
+import onnx
+from fusion_options import FusionOptions
+from onnx_model_clip import ClipOnnxModel
+from onnx_model_mmdit import MmditOnnxModel
+from onnx_model_t5 import T5OnnxModel
+from onnx_model_unet import UnetOnnxModel
+from onnx_model_vae import VaeOnnxModel
+from optimizer import optimize_by_onnxruntime, optimize_model
+from packaging import version
+
+import onnxruntime
+
+logger = logging.getLogger(__name__)
+
+
+def has_external_data(onnx_model_path):
+    original_model = onnx.load_model(str(onnx_model_path), load_external_data=False)
+    for initializer in original_model.graph.initializer:
+        if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL:
+            return True
+    return False
+
+
+def is_sd_3(source_dir: Path):
+    return (source_dir / "text_encoder_3").exists()
+
+
+def is_sdxl(source_dir: Path):
+    return (
+        (source_dir / "text_encoder_2").exists()
+        and not (source_dir / "text_encoder_3").exists()
+        and not (source_dir / "transformer").exists()
+    )
+
+
+def is_flux(source_dir: Path):
+    return (
+        (source_dir / "text_encoder_2").exists()
+        and not (source_dir / "text_encoder_3").exists()
+        and (source_dir / "transformer").exists()
+    )
+
+
+def _classify_pipeline_type(source_dir: Path):
+    # May also check _class_name in model_index.json (like StableDiffusion3Pipeline or FluxPipeline) to classify.
+    if is_sd_3(source_dir):
+        return "sd3"
+
+    if is_flux(source_dir):
+        return "flux"
+
+    if is_sdxl(source_dir):
+        return "sdxl"
+
+    # sd 1.x and 2.x
+    return "sd"
+
+
+def _get_model_list(pipeline_type: str):
+    if pipeline_type == "sd3":
+        return ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"]
+
+    if pipeline_type == "flux":
+        return ["text_encoder", "text_encoder_2", "transformer", "vae_encoder", "vae_decoder"]
+
+    if pipeline_type == "sdxl":
+        return ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"]
+
+    assert pipeline_type == "sd"
+    return ["text_encoder", "unet", "vae_encoder", "vae_decoder"]
+
+
+def _optimize_sd_pipeline(
+    source_dir: Path,
+    target_dir: Path,
+    pipeline_type: str,
+    model_list: list[str],
+    use_external_data_format: bool | None,
+    float16: bool,
+    bfloat16: bool,
+    force_fp32_ops: list[str],
+    enable_runtime_optimization: bool,
+    args,
+):
+    """Optimize onnx models used in the stable diffusion onnx pipeline and optionally convert to float16.
+
+    Args:
+        source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models.
+        target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models.
+        model_list (List[str]): list of directory names with onnx model.
+        use_external_data_format (Optional[bool]): use external data format.
+        float16 (bool): use half precision.
+        bfloat16 (bool): use bfloat16 as fallback if float16 is also provided.
+        force_fp32_ops (List[str]): operators that are forced to run in float32.
+        enable_runtime_optimization (bool): run graph optimization using Onnx Runtime.
+
+    Raises:
+        RuntimeError: input onnx model does not exist
+        RuntimeError: output onnx model path already exists
+    """
+    is_flux_pipeline = pipeline_type == "flux"
+    model_type_mapping = {
+        "transformer": "mmdit",
+        "unet": "unet",
+        "vae_encoder": "vae",
+        "vae_decoder": "vae",
+        "text_encoder": "clip",
+        "text_encoder_2": "t5" if is_flux_pipeline else "clip",
+        "text_encoder_3": "t5",  # t5-v1_1-xxl is used in SD 3.x text_encoder_3 and Flux text_encoder_2.
+        "safety_checker": "unet",
+    }
+
+    model_type_class_mapping = {
+        "unet": UnetOnnxModel,
+        "vae": VaeOnnxModel,
+        "clip": ClipOnnxModel,
+        "t5": T5OnnxModel,
+        "mmdit": MmditOnnxModel,
+    }
+
+    force_fp32_operators = {
+        "unet": [],
+        "vae_encoder": [],
+        "vae_decoder": [],
+        "text_encoder": [],
+        "text_encoder_2": [],
+        "safety_checker": [],
+        "text_encoder_3": [],
+        "transformer": [],
+    }
+
+    # The node block list is generated by running the fp32 model and getting statistics of node inputs and outputs.
+    # Nodes with any float or double input or output whose values fall out of float16 range are candidates.
+    #   python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp32_opt
+    #   export ORT_DEBUG_NODE_IO_DUMP_STATISTICS_DATA=1
+    #   export ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA=1
+    #   export ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA=1
+    #   python benchmark.py --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp32_opt -e optimum >stdout.txt 2>stderr.txt
+    # Warning: The node names might change with different export settings. See benchmark_flux.sh for the settings.
+    flux_node_block_list = {
+        "text_encoder_2": [
+            "/encoder/block.10/layer.1/DenseReluDense/wo/MatMul",
+            "SkipLayerNorm_20",
+            "SkipLayerNorm_21",
+            "SkipLayerNorm_22",
+            "SkipLayerNorm_23",
+            "SkipLayerNorm_24",
+            "SkipLayerNorm_25",
+            "SkipLayerNorm_26",
+            "SkipLayerNorm_27",
+            "SkipLayerNorm_28",
+            "SkipLayerNorm_29",
+            "SkipLayerNorm_30",
+            "SkipLayerNorm_31",
+            "SkipLayerNorm_32",
+            "SkipLayerNorm_33",
+            "SkipLayerNorm_34",
+            "SkipLayerNorm_35",
+            "SkipLayerNorm_36",
+            "SkipLayerNorm_37",
+            "SkipLayerNorm_38",
+            "SkipLayerNorm_39",
+            "SkipLayerNorm_40",
+            "SkipLayerNorm_41",
+            "SkipLayerNorm_42",
+            "SkipLayerNorm_43",
+            "SkipLayerNorm_44",
+            "SkipLayerNorm_45",
+            "/encoder/block.23/layer.1/DenseReluDense/wo/MatMul",
+            "SkipLayerNorm_46",
+        ],
+        "vae_decoder": [
+            "/decoder/mid_block/attentions.0/MatMul",
+            "/decoder/mid_block/attentions.0/Softmax",
+        ],
+        "transformer": [
+            "/transformer_blocks.18/Mul_5",
+            "/transformer_blocks.18/Add_7",
+            "/Concat_1",
+            "LayerNorm_76",
+            "/single_transformer_blocks.0/Add",
+            "LayerNorm_77",
+            "/single_transformer_blocks.1/Add",
+            "LayerNorm_78",
+            "/single_transformer_blocks.2/Add",
+            "LayerNorm_79",
+            "/single_transformer_blocks.3/Add",
+            "LayerNorm_80",
+            "/single_transformer_blocks.4/Add",
+            "LayerNorm_81",
+            "/single_transformer_blocks.5/Add",
+            "LayerNorm_82",
+            "/single_transformer_blocks.6/Add",
+            "LayerNorm_83",
+            "/single_transformer_blocks.7/Add",
+            "LayerNorm_84",
+            "/single_transformer_blocks.8/Add",
+            "LayerNorm_85",
+            "/single_transformer_blocks.9/Add",
+            "LayerNorm_86",
+            "/single_transformer_blocks.10/Add",
+            "LayerNorm_87",
+            "/single_transformer_blocks.11/Add",
+            "LayerNorm_88",
+            "/single_transformer_blocks.12/Add",
+            "LayerNorm_89",
+            "/single_transformer_blocks.13/Add",
+            "LayerNorm_90",
+            "/single_transformer_blocks.14/Add",
+            "LayerNorm_91",
+            "/single_transformer_blocks.15/Add",
+            "LayerNorm_92",
+            "/single_transformer_blocks.16/Add",
+            "LayerNorm_93",
+            "/single_transformer_blocks.17/Add",
+            "LayerNorm_94",
+            "/single_transformer_blocks.18/Add",
+            "LayerNorm_95",
+            "/single_transformer_blocks.19/Add",
+            "LayerNorm_96",
+            "/single_transformer_blocks.20/Add",
+            "LayerNorm_97",
+            "/single_transformer_blocks.21/Add",
+            "LayerNorm_98",
+            "/single_transformer_blocks.22/Add",
+            "LayerNorm_99",
+            "/single_transformer_blocks.23/Add",
+            "LayerNorm_100",
+            "/single_transformer_blocks.24/Add",
+            "LayerNorm_101",
+            "/single_transformer_blocks.25/Add",
+            "LayerNorm_102",
+            "/single_transformer_blocks.26/Add",
+            "LayerNorm_103",
+            "/single_transformer_blocks.27/Add",
+            "LayerNorm_104",
+            "/single_transformer_blocks.28/Add",
+            "LayerNorm_105",
+            "/single_transformer_blocks.29/Add",
+            "LayerNorm_106",
+            "/single_transformer_blocks.30/Add",
+            "LayerNorm_107",
+            "/single_transformer_blocks.31/Add",
+            "LayerNorm_108",
+            "/single_transformer_blocks.32/Add",
+            "LayerNorm_109",
+            "/single_transformer_blocks.33/Add",
+            "LayerNorm_110",
+            "/single_transformer_blocks.34/Add",
+            "LayerNorm_111",
+            "/single_transformer_blocks.35/Add",
+            "LayerNorm_112",
+            "/single_transformer_blocks.36/Add",
+            "LayerNorm_113",
+            "/single_transformer_blocks.37/Add",
+            "/Shape",
+            "/Slice",
+        ],
+    }
+
+    sd3_node_block_list = {"text_encoder_3": flux_node_block_list["text_encoder_2"]}
+
+    if force_fp32_ops:
+        for fp32_operator in force_fp32_ops:
+            parts = fp32_operator.split(":")
+            if len(parts) == 2 and parts[0] in force_fp32_operators and (parts[1] and parts[1][0].isupper()):
+                force_fp32_operators[parts[0]].append(parts[1])
+            else:
+                raise ValueError(
+                    f"--force_fp32_ops shall be in the format of module:operator like unet:Attention, got {fp32_operator}"
+                )
+
+    op_counters = {}
+    for name, model_type in model_type_mapping.items():
+        onnx_model_path = source_dir / name / "model.onnx"
+        if not os.path.exists(onnx_model_path):
+            if name != "safety_checker" and name in model_list:
+                logger.warning("input onnx model does not exist: %s", onnx_model_path)
+            # some models are optional so we do not raise an error here.
+            continue
+
+        # Prepare output directory
+        optimized_model_path = target_dir / name / "model.onnx"
+        if os.path.exists(optimized_model_path):
+            if not args.overwrite:
+                logger.warning("Skipped optimization since the target file exists: %s", optimized_model_path)
+                continue
+        output_dir = optimized_model_path.parent
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        if use_external_data_format is None:
+            use_external_data_format = has_external_data(onnx_model_path)
+
+        # Run graph fusion before fp16 conversion; otherwise some patterns cannot be fused later.
+        logger.info("Optimize %s ...", onnx_model_path)
+
+        args.model_type = model_type
+        fusion_options = FusionOptions.parse(args)
+
+        if model_type in ["unet"]:
+            # Some optimizations are not available in v1.14 or older versions: packed QKV and BiasAdd.
+            has_all_optimizations = version.parse(onnxruntime.__version__) >= version.parse("1.15.0")
+            fusion_options.enable_packed_kv = float16 and fusion_options.enable_packed_kv
+            fusion_options.enable_packed_qkv = float16 and has_all_optimizations and fusion_options.enable_packed_qkv
+            fusion_options.enable_bias_add = has_all_optimizations and fusion_options.enable_bias_add
+
+        m = optimize_model(
+            str(onnx_model_path),
+            model_type=model_type,
+            num_heads=0,  # will be deduced from graph
+            hidden_size=0,  # will be deduced from graph
+            opt_level=0,
+            optimization_options=fusion_options,
+            use_gpu=True,
+            provider=args.provider,
+        )
+
+        if float16:
+            model_node_block_list = (
+                flux_node_block_list if is_flux_pipeline else sd3_node_block_list if pipeline_type == "sd3" else {}
+            )
+            if name in model_node_block_list:
+                # Opset 12 does not support bfloat16.
+                # By default, optimum exports the T5 model with opset 12, so we need to check the opset version.
+                use_bfloat16 = bfloat16
+                if use_bfloat16:
+                    for opset in m.model.opset_import:
+                        if opset.domain in ["", "ai.onnx"] and opset.version < 13:
+                            logger.warning(
+                                "onnx model requires opset 13 or higher to use bfloat16. Fall back to float32."
+                            )
+                            use_bfloat16 = False
+
+                m.convert_float_to_float16(
+                    keep_io_types=False,
+                    node_block_list=model_node_block_list[name],
+                    use_bfloat16_as_blocked_nodes_dtype=use_bfloat16,
+                )
+            # For SD-XL, using FP16 in the VAE decoder causes NaN and a black image, so we keep it in FP32.
+            elif pipeline_type in ["sdxl"] and name in ["vae_decoder"]:
+                logger.info("Skip converting %s to float16 to avoid NaN", name)
+            else:
+                logger.info("Convert %s to float16 ...", name)
+                m.convert_float_to_float16(
+                    keep_io_types=False,
+                    op_block_list=force_fp32_operators[name],
+                )
+
+        if enable_runtime_optimization:
+            # Use this step to see the final graph that is executed by Onnx Runtime.
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                # Save to a temporary file so that we can load it with Onnx Runtime.
+                logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
+                tmp_model_path = Path(tmp_dir) / "model.onnx"
+                m.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
+                ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
+                optimize_by_onnxruntime(
+                    str(tmp_model_path),
+                    use_gpu=True,
+                    provider=args.provider,
+                    optimized_model_path=str(ort_optimized_model_path),
+                    save_as_external_data=use_external_data_format,
+                )
+                model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
+                m = model_type_class_mapping[model_type](model)
+
+        m.get_operator_statistics()
+        op_counters[name] = m.get_fused_operator_statistics()
+        m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format)
+        logger.info("%s is optimized", name)
+        logger.info("*" * 20)
+
+    return op_counters
+
+
+def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: list[str]):
+    """Copy extra directories that do not have an onnx model.
+
+    Args:
+        source_dir (Path): source directory
+        target_dir (Path): target directory
+        model_list (List[str]): list of directory names with onnx model.
+
+    Raises:
+        RuntimeError: source path does not exist
+    """
+    extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "tokenizer_3", "feature_extractor"]
+
+    for name in extra_dirs:
+        source_path = source_dir / name
+        if not os.path.exists(source_path):
+            continue
+
+        target_path = target_dir / name
+        if target_path.exists():
+            shutil.rmtree(target_path)
+        shutil.copytree(source_path, target_path)
+        logger.info("%s => %s", source_path, target_path)
+
+    extra_files = ["model_index.json"]
+    for name in extra_files:
+        source_path = source_dir / name
+        if not os.path.exists(source_path):
+            raise RuntimeError(f"source path does not exist: {source_path}")
+
+        target_path = target_dir / name
+        shutil.copyfile(source_path, target_path)
+        logger.info("%s => %s", source_path, target_path)
+
+    # Some directories are optional
+    for onnx_model_dir in model_list:
+        source_path = source_dir / onnx_model_dir / "config.json"
+        target_path = target_dir / onnx_model_dir / "config.json"
+        if source_path.exists():
+            target_path.parent.mkdir(parents=True, exist_ok=True)
+            shutil.copyfile(source_path, target_path)
+            logger.info("%s => %s", source_path, target_path)
+
+
+def optimize_stable_diffusion_pipeline(
+    input_dir: str,
+    output_dir: str,
+    overwrite: bool,
+    use_external_data_format: bool | None,
+    float16: bool,
+    enable_runtime_optimization: bool,
+    args,
+):
+    if os.path.exists(output_dir):
+        if overwrite:
+            shutil.rmtree(output_dir, ignore_errors=True)
+
+    source_dir = Path(input_dir)
+    target_dir = Path(output_dir)
+    target_dir.mkdir(parents=True, exist_ok=True)
+
+    pipeline_type = _classify_pipeline_type(source_dir)
+    model_list = _get_model_list(pipeline_type)
+
+    _copy_extra_directory(source_dir, target_dir, model_list)
+
+    return _optimize_sd_pipeline(
+        source_dir,
+        target_dir,
+        pipeline_type,
+        model_list,
+        use_external_data_format,
+        float16,
+        args.bfloat16,
+        args.force_fp32_ops,
+        enable_runtime_optimization,
+        args,
+    )
+
+
+def parse_arguments(argv: list[str] | None = None):
+    """Parse arguments
+
+    Returns:
+        Namespace: arguments
+    """
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-i",
+        "--input",
+        required=True,
+        type=str,
+        help="Root of input directory of stable diffusion onnx pipeline with float32 models.",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--output",
+        required=True,
+        type=str,
+        help="Root of output directory of stable diffusion onnx pipeline with optimized models.",
+    )
+
+    parser.add_argument(
+        "--float16",
+        required=False,
+        action="store_true",
+        help="Output models of float16, except some nodes fall back to float32 or bfloat16 to avoid overflow.",
+    )
+    parser.set_defaults(float16=False)
+
+    parser.add_argument(
+        "--bfloat16",
+        required=False,
+        action="store_true",
+        help="Allow bfloat16 as fallback if --float16 is also provided.",
+    )
+    parser.set_defaults(bfloat16=False)
+
+    parser.add_argument(
+        "--force_fp32_ops",
+        required=False,
+        nargs="+",
+        type=str,
+        help="Force given operators (like unet:Attention) to run in float32. It is case sensitive!",
+    )
+
+    parser.add_argument(
+        "--inspect",
+        required=False,
+        action="store_true",
+        help="Save the optimized graph from Onnx Runtime. "
+        "This option has no impact on inference performance except it might reduce session creation time.",
+    )
+    parser.set_defaults(inspect=False)
+
+    parser.add_argument(
+        "--overwrite",
+        required=False,
+        action="store_true",
+        help="Overwrite existing files.",
+    )
+    parser.set_defaults(overwrite=False)
+
+    parser.add_argument(
+        "-e",
+        "--use_external_data_format",
+        required=False,
+        action="store_true",
+        help="Onnx models larger than 2GB need to use external data format. "
+        "If specified, save each onnx model to two files: one for the onnx graph, another for weights. "
+        "If not specified, use the same format as the original model by default.",
+    )
+    parser.set_defaults(use_external_data_format=None)
+
+    parser.add_argument(
+        "--provider",
+        required=False,
+        type=str,
+        default=None,
+        help="Execution provider to use.",
+    )
+
+    FusionOptions.add_arguments(parser)
+
+    args = parser.parse_args(argv)
+    return args
+
+
+def main(argv: list[str] | None = None):
+    warnings.warn(
+        "This example is deprecated. Use the Olive recipe instead: "
+        "https://github.com/microsoft/olive-recipes/tree/main",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    args = parse_arguments(argv)
+
+    logger.info("Arguments: %s", str(args))
+
+    # Return op counters for testing purposes.
+    return optimize_stable_diffusion_pipeline(
+        args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args
+    )
+
+
+if __name__ == "__main__":
+    logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO)
+    main()
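Since main() accepts an argv list and returns the op counters, the same optimization can be driven programmatically. A hedged sketch, assuming it runs from this stable_diffusion directory so the local imports resolve, and that ./sd-v1-5 contains the float32 ONNX pipeline layout described in the header comment (text_encoder/model.onnx, unet/model.onnx, vae_encoder/model.onnx, vae_decoder/model.onnx, plus model_index.json); note the module itself warns that this example is deprecated in favor of the Olive recipes:

    import logging

    from optimize_pipeline import main

    logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO)

    # Equivalent to: python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16
    op_counters = main(["-i", "./sd-v1-5", "-o", "./sd-v1-5-fp16", "--float16"])
    print(op_counters)  # fused-operator statistics per submodel (used for testing)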