onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/float16.py

@@ -0,0 +1,501 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

# This file is modified from https://github.com/microsoft/onnxconverter-common/blob/master/onnxconverter_common/float16.py
# Modifications:
# (1) Update default values of min_positive_val and max_finite_val
# (2) keep_io_types can be a list of names
# (3) convert initializers if needed to preserve precision
# (4) add force_fp16_initializers option
# (5) handle Resize and GroupNorm with mixed float inputs
# (6) allow convert_float_to_float16 to accept a model path

import itertools
import logging
import os
import tempfile

import numpy as np
import onnx
from onnx import AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, helper, numpy_helper
from onnx.shape_inference import infer_shapes, infer_shapes_path
from packaging import version

logger = logging.getLogger(__name__)


def _npfloat16_to_int(np_list):
    """
    Convert numpy float16 to python int.

    :param np_list: numpy float16 list
    :return int_list: python int list
    """
    return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list]


def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0):
    """
    Convert float32 numpy array to float16 without changing sign or finiteness.
    Positive values less than min_positive_val are mapped to min_positive_val.
    Positive finite values greater than max_finite_val are mapped to max_finite_val.
    Similar for negative values. NaN, 0, inf, and -inf are unchanged.
    """

    def between(a, b, c):
        return np.logical_and(a < b, b < c)

    if np_array[np.where(np_array > 0)].shape[0] > 0:
        positive_max = np_array[np.where(np_array > 0)].max()
        positive_min = np_array[np.where(np_array > 0)].min()
        if positive_max >= max_finite_val:
            logger.debug(f"the float32 number {positive_max} will be truncated to {max_finite_val}")
        if positive_min <= min_positive_val:
            logger.debug(f"the float32 number {positive_min} will be truncated to {min_positive_val}")

    if np_array[np.where(np_array < 0)].shape[0] > 0:
        negative_max = np_array[np.where(np_array < 0)].max()
        negative_min = np_array[np.where(np_array < 0)].min()
        if negative_min <= -max_finite_val:
            logger.debug(f"the float32 number {negative_min} will be truncated to {-max_finite_val}")
        if negative_max >= -min_positive_val:
            logger.debug(f"the float32 number {negative_max} will be truncated to {-min_positive_val}")

    np_array = np.where(between(0, np_array, min_positive_val), min_positive_val, np_array)
    np_array = np.where(between(-min_positive_val, np_array, 0), -min_positive_val, np_array)
    np_array = np.where(between(max_finite_val, np_array, float("inf")), max_finite_val, np_array)
    np_array = np.where(between(float("-inf"), np_array, -max_finite_val), -max_finite_val, np_array)
    return np.float16(np_array)
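

# Illustrative example (not part of the packaged file): out-of-range magnitudes
# are saturated instead of overflowing to inf or flushing to zero:
#
#   convert_np_to_float16(np.array([1e-9, 70000.0, -1e9], dtype=np.float32))
#   # -> approximately array([5.96e-08, 65504.0, -65504.0], dtype=float16)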


def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0):
    """Convert tensor float to float16.

    Args:
        tensor (TensorProto): the tensor to convert.
        min_positive_val (float, optional): minimal positive value. Defaults to 5.96e-08.
        max_finite_val (float, optional): maximal finite value. Defaults to 65504.0.

    Raises:
        ValueError: input type is not TensorProto.

    Returns:
        TensorProto: the converted tensor.
    """

    if not isinstance(tensor, TensorProto):
        raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")

    if tensor.data_type == TensorProto.FLOAT:
        tensor.data_type = TensorProto.FLOAT16
        # convert float_data (float type) to float16 and write to int32_data
        if tensor.float_data:
            float16_data = convert_np_to_float16(np.array(tensor.float_data), min_positive_val, max_finite_val)
            int_list = _npfloat16_to_int(float16_data)
            tensor.int32_data[:] = int_list
            tensor.float_data[:] = []
        # convert raw_data (bytes type)
        if tensor.raw_data:
            # convert tensor.raw_data to float32
            float32_list = np.frombuffer(tensor.raw_data, dtype="float32")
            # convert float32 to float16
            float16_list = convert_np_to_float16(float32_list, min_positive_val, max_finite_val)
            # convert float16 to bytes and write back to raw_data
            tensor.raw_data = float16_list.tobytes()
    return tensor


def make_value_info_from_tensor(tensor):
    shape = numpy_helper.to_array(tensor).shape
    return helper.make_tensor_value_info(tensor.name, tensor.data_type, shape)
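

# Illustrative example (not part of the packaged file): conversion mutates the
# proto in place and returns it:
#
#   t = helper.make_tensor("w", TensorProto.FLOAT, [2],
#                          np.array([1.0, 2.0], dtype=np.float32).tobytes(), raw=True)
#   convert_tensor_float_to_float16(t)
#   assert t.data_type == TensorProto.FLOAT16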


DEFAULT_OP_BLOCK_LIST = [
    "ArrayFeatureExtractor",
    "Binarizer",
    "CastMap",
    "CategoryMapper",
    "DictVectorizer",
    "FeatureVectorizer",
    "Imputer",
    "LabelEncoder",
    "LinearClassifier",
    "LinearRegressor",
    "Normalizer",
    "OneHotEncoder",
    "RandomUniformLike",
    "SVMClassifier",
    "SVMRegressor",
    "Scaler",
    "TreeEnsembleClassifier",
    "TreeEnsembleRegressor",
    "TreeEnsemble",
    "ZipMap",
    "NonMaxSuppression",
    "TopK",
    "RoiAlign",
    "Range",
    "CumSum",
    "Min",
    "Max",
    "Upsample",
]


# Some operators have their data type fixed as float for some inputs. The key is op_type; the value is a list of input indices.
# Note that DirectML allows float16 gamma and beta in GroupNorm; the force_fp16_inputs parameter can override this.
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]}
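

# Illustrative example (not part of the packaged file): to let GroupNorm gamma and
# beta (inputs 1 and 2) be converted to float16 anyway, e.g. for DirectML, pass
# force_fp16_inputs={"GroupNorm": [1, 2]} to convert_float_to_float16 below.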


class InitializerTracker:
    """Keep track of a float initializer and the fp32/fp16 nodes that consume it."""

    def __init__(self, initializer: TensorProto):
        self.initializer = initializer
        self.fp32_nodes = []
        self.fp16_nodes = []

    def add_node(self, node: NodeProto, is_node_blocked):
        if is_node_blocked:
            self.fp32_nodes.append(node)
        else:
            self.fp16_nodes.append(node)


def convert_float_to_float16(
    model,
    min_positive_val=5.96e-08,
    max_finite_val=65504.0,
    keep_io_types=False,
    disable_shape_infer=False,
    op_block_list=None,
    node_block_list=None,
    force_fp16_initializers=False,
    force_fp16_inputs=None,
    use_bfloat16_as_blocked_nodes_dtype=False,
):
    """Convert tensor float type in the input ONNX model to tensor float16.

    Args:
        model (ModelProto or str): The ONNX model or path of the model to convert.
        min_positive_val (float, optional): minimal positive value. Defaults to 5.96e-08.
        max_finite_val (float, optional): maximal finite value of float16. Defaults to 65504.
        keep_io_types (Union[bool, List[str]], optional): a boolean or a list of float32 input/output names.
            If True, model inputs/outputs are left as float32. Defaults to False.
        disable_shape_infer (bool, optional): skip running onnx shape/type inference.
            Useful if shape inference has already been done. Defaults to False.
        op_block_list (List[str], optional): list of op types to leave as float32.
            Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST`.
        node_block_list (List[str], optional): list of node names to leave as float32. Defaults to None.
        force_fp16_initializers (bool): force converting all float initializers to float16.
            Defaults to False, which converts only those needed to avoid precision loss.
        force_fp16_inputs (Dict[str, List[int]]): force the conversion of the listed inputs of some operators
            to float16, even if this script's preference is to keep them in float32.
        use_bfloat16_as_blocked_nodes_dtype (bool): cast the inputs/outputs of blocked nodes to bfloat16
            instead of float32. Defaults to False.

    Raises:
        ValueError: input type is not ModelProto.

    Returns:
        ModelProto: converted model.
    """
    assert min_positive_val >= 5.96e-08, (
        "invalid min_positive_val. smallest positive float16 value: subnormal 5.96e-08, and normalized 6.104e-05"
    )
    assert max_finite_val <= float(np.finfo(np.float16).max), "invalid max_finite_val. largest float16 value: 65504"

    force_fp16_inputs_dict = {} if force_fp16_inputs is None else force_fp16_inputs

    if isinstance(model, str):
        model_path = model
        if version.parse(onnx.__version__) >= version.parse("1.8.0") and not disable_shape_infer:
            # shape_infer_model_path should be in the same folder as model_path
            with tempfile.NamedTemporaryFile(dir=os.path.dirname(model_path)) as tmpfile:
                shape_infer_model_path = tmpfile.name
                # infer_shapes_path can be used for models >2GB, and infer_shapes cannot.
                infer_shapes_path(model_path, shape_infer_model_path)
                model = onnx.load(shape_infer_model_path)
                disable_shape_infer = True
        else:
            model = onnx.load(model_path)

    if not isinstance(model, ModelProto):
        raise ValueError(f"Expected an ONNX ModelProto but got {type(model)}")

    func_infer_shape = None
    if not disable_shape_infer and version.parse(onnx.__version__) >= version.parse("1.2.0"):
        try:
            func_infer_shape = infer_shapes
        finally:
            pass

    # create block lists
    if op_block_list is None:
        op_block_list = DEFAULT_OP_BLOCK_LIST
    if node_block_list is None:
        node_block_list = []
    op_block_list = set(op_block_list)
    node_block_list = set(node_block_list)

    logger.debug(
        f"fp16 parameters: min_positive_val={min_positive_val} max_finite_val={max_finite_val} keep_io_types={keep_io_types} disable_shape_infer={disable_shape_infer} op_block_list={op_block_list} node_block_list={node_block_list} force_fp16_initializers={force_fp16_initializers}"
    )

    # create a queue for BFS over the (possibly nested) graphs
    queue = []
    value_info_list = []
    node_list = []

    # Some operators (like Resize or GroupNorm) have data type fixed as float for some inputs.
    # When the model is converted to float16, the types become mixed: some inputs are float32 and some are float16.
    # This list keeps track of such nodes that are not in the block list.
    mixed_float_type_node_list = []

    # type inference on input model
    if func_infer_shape is not None:
        model = func_infer_shape(model)
    queue.append(model)
    name_mapping = {}
    graph_io_to_skip = set()
    io_casts = set()

    fp32_inputs = [n.name for n in model.graph.input if n.type.tensor_type.elem_type == TensorProto.FLOAT]
    fp32_outputs = [n.name for n in model.graph.output if n.type.tensor_type.elem_type == TensorProto.FLOAT]
    if isinstance(keep_io_types, list):
        fp32_inputs = [n for n in fp32_inputs if n in keep_io_types]
        fp32_outputs = [n for n in fp32_outputs if n in keep_io_types]
    elif not keep_io_types:
        fp32_inputs = []
        fp32_outputs = []

    for i, n in enumerate(model.graph.input):
        if n.name in fp32_inputs:
            output_name = "graph_input_cast_" + str(i)
            name_mapping[n.name] = output_name
            graph_io_to_skip.add(n.name)

            node_name = "graph_input_cast" + str(i)
            new_value_info = model.graph.value_info.add()
            new_value_info.CopyFrom(n)
            new_value_info.name = output_name
            new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16
            # add Cast node (from tensor(float) to tensor(float16)) after graph input
            new_node = [helper.make_node("Cast", [n.name], [output_name], to=TensorProto.FLOAT16, name=node_name)]
            model.graph.node.extend(new_node)
            value_info_list.append(new_value_info)
            io_casts.add(node_name)

    for i, n in enumerate(model.graph.output):
        if n.name in fp32_outputs:
            input_name = "graph_output_cast_" + str(i)
            name_mapping[n.name] = input_name
            graph_io_to_skip.add(n.name)

            node_name = "graph_output_cast" + str(i)
            # add Cast node (from tensor(float16) to tensor(float)) before graph output
            new_value_info = model.graph.value_info.add()
            new_value_info.CopyFrom(n)
            new_value_info.name = input_name
            new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16
            new_node = [helper.make_node("Cast", [input_name], [n.name], to=1, name=node_name)]  # to=1: TensorProto.FLOAT
            model.graph.node.extend(new_node)
            value_info_list.append(new_value_info)
            io_casts.add(node_name)

    fp32_initializers: dict[str, InitializerTracker] = {}
    while queue:
        next_level = []
        for q in queue:
            # if q is model, push q.graph (GraphProto)
            if isinstance(q, ModelProto):
                next_level.append(q.graph)
            # if q is model.graph, push q.node.attribute (AttributeProto)
            if isinstance(q, GraphProto):
                for n in q.initializer:  # TensorProto type
                    if n.data_type == TensorProto.FLOAT:
                        assert n.name not in fp32_initializers
                        fp32_initializers[n.name] = InitializerTracker(n)

                for n in q.node:
                    # if n is in the block list (doesn't support float16), no conversion for the node,
                    # and save the node for further processing
                    if n.name in io_casts:
                        continue
                    for i in range(len(n.input)):
                        if n.input[i] in name_mapping:
                            n.input[i] = name_mapping[n.input[i]]
                    for i in range(len(n.output)):
                        if n.output[i] in name_mapping:
                            n.output[i] = name_mapping[n.output[i]]

                    is_node_blocked = n.op_type in op_block_list or n.name in node_block_list
                    for i, input_name in enumerate(n.input):
                        if input_name in fp32_initializers:
                            # For Resize/GroupNorm, only the first input can be float16
                            use_fp32_weight = is_node_blocked or (
                                i in ALWAYS_FLOAT_INPUTS.get(n.op_type, [])
                                and i not in force_fp16_inputs_dict.get(n.op_type, [])
                            )
                            fp32_initializers[input_name].add_node(n, use_fp32_weight)

                    if is_node_blocked:
                        node_list.append(n)
                    else:
                        if n.op_type == "Cast":
                            for attr in n.attribute:
                                if attr.name == "to" and attr.i == TensorProto.FLOAT:
                                    attr.i = TensorProto.FLOAT16
                                    break

                        if n.op_type in [
                            "EyeLike",
                            "Multinomial",
                            "RandomNormal",
                            "RandomNormalLike",
                            "RandomUniform",
                            "RandomUniformLike",
                            "SequenceEmpty",
                            "Bernoulli",
                        ]:
                            has_dtype = False
                            for attr in n.attribute:
                                if attr.name == "dtype":
                                    has_dtype = True
                                    if attr.i == TensorProto.FLOAT:
                                        attr.i = TensorProto.FLOAT16

                            # The dtype attribute is optional and defaults to FLOAT in the following operators,
                            # so we need to add a dtype attribute to specify the data type float16
                            if (n.op_type in ["RandomNormal", "RandomUniform", "SequenceEmpty"]) and not has_dtype:
                                n.attribute.extend([helper.make_attribute("dtype", TensorProto.FLOAT16)])

                        # For Resize/GroupNorm, attribute data type cannot be changed
                        if n.op_type not in ALWAYS_FLOAT_INPUTS or n.op_type in force_fp16_inputs_dict:
                            for attr in n.attribute:
                                next_level.append(attr)  # noqa: PERF402
                        else:
                            mixed_float_type_node_list.append(n)

            # if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto)
            # and process node.attribute.t and node.attribute.tensors (TensorProto)
            if isinstance(q, AttributeProto):
                next_level.append(q.g)
                for n in q.graphs:
                    next_level.append(n)  # noqa: PERF402
                q.t.CopyFrom(convert_tensor_float_to_float16(q.t, min_positive_val, max_finite_val))
                for n in q.tensors:
                    n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val)  # noqa: PLW2901
            # if q is graph, process input, output and value_info (ValueInfoProto)
            if isinstance(q, GraphProto):
                # Note that float initializers tracked by fp32_initializers will be processed later.
                # for all ValueInfoProto with tensor(float) type in input, output and value_info, convert them to
                # tensor(float16) except map and seq(map). And save them in value_info_list for further processing
                for n in itertools.chain(q.input, q.output, q.value_info):
                    if n.type.tensor_type.elem_type == TensorProto.FLOAT:
                        if n.name not in graph_io_to_skip:
                            n.type.tensor_type.elem_type = TensorProto.FLOAT16
                            value_info_list.append(n)
                    if n.type.HasField("sequence_type"):
                        if n.type.sequence_type.elem_type.tensor_type.elem_type == TensorProto.FLOAT:
                            if n.name not in graph_io_to_skip:
                                n.type.sequence_type.elem_type.tensor_type.elem_type = TensorProto.FLOAT16
                                value_info_list.append(n)

        queue = next_level

    for value in fp32_initializers.values():
        # By default, to avoid precision loss, do not convert an initializer to fp16 when it is used only by fp32 nodes.
        if force_fp16_initializers or value.fp16_nodes:
            value.initializer = convert_tensor_float_to_float16(value.initializer, min_positive_val, max_finite_val)
            value_info_list.append(make_value_info_from_tensor(value.initializer))
            if value.fp32_nodes and not force_fp16_initializers:
                logger.info(
                    f"initializer is used by both fp32 and fp16 nodes. Consider adding these nodes to the block list: {value.fp16_nodes}"
                )

    # Some operators have data type fixed as float for some inputs. Add a float16-to-float Cast for those inputs.
    for node in mixed_float_type_node_list:
        for i, input_name in enumerate(node.input):
            if i not in ALWAYS_FLOAT_INPUTS[node.op_type] or i in force_fp16_inputs_dict.get(node.op_type, []):
                continue
            for value_info in value_info_list:
                if input_name == value_info.name:
                    # create new value_info for current node's new input name
                    new_value_info = model.graph.value_info.add()
                    new_value_info.CopyFrom(value_info)
                    output_name = node.name + "_input_cast_" + str(i)
                    new_value_info.name = output_name
                    new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
                    # add Cast node (from tensor(float16) to tensor(float)) before current node
                    node_name = node.name + "_input_cast" + str(i)
                    new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)]  # to=1: TensorProto.FLOAT
                    model.graph.node.extend(new_node)
                    # change current node's input name
                    node.input[i] = output_name
                    break

    accuracy_type = TensorProto.BFLOAT16 if use_bfloat16_as_blocked_nodes_dtype else TensorProto.FLOAT
    # process the nodes in the block list that don't support tensor(float16)
    for node in node_list:
        # if an input's name is in value_info_list, meaning the input is tensor(float16) type,
        # insert a float16-to-float (or bfloat16) Cast node before the node,
        # change the current node's input name, and create new value_info for the new name
        for i in range(len(node.input)):
            input_name = node.input[i]
            for value_info in value_info_list:
                if input_name == value_info.name:
                    # create new value_info for current node's new input name
                    new_value_info = model.graph.value_info.add()
                    new_value_info.CopyFrom(value_info)
                    output_name = node.name + "_input_cast_" + str(i)
                    new_value_info.name = output_name
                    new_value_info.type.tensor_type.elem_type = accuracy_type
                    # add Cast node (from tensor(float16) to tensor(float)) before current node
                    node_name = node.name + "_input_cast" + str(i)
                    new_node = [helper.make_node("Cast", [input_name], [output_name], to=accuracy_type, name=node_name)]
                    model.graph.node.extend(new_node)
                    # change current node's input name
                    node.input[i] = output_name
                    break
        # if an output's name is in value_info_list, meaning the output is tensor(float16) type, insert a float-to-
        # float16 Cast node after the node, change the current node's output name, and create new value_info for the new name
        for i in range(len(node.output)):
            output = node.output[i]
            for value_info in value_info_list:
                if output == value_info.name:
                    # create new value_info for current node's new output
                    new_value_info = model.graph.value_info.add()
                    new_value_info.CopyFrom(value_info)
                    input_name = node.name + "_output_cast_" + str(i)
                    new_value_info.name = input_name
                    new_value_info.type.tensor_type.elem_type = accuracy_type
                    # add Cast node (from tensor(float) to tensor(float16)) after current node
                    node_name = node.name + "_output_cast" + str(i)
                    new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)]  # to=10: TensorProto.FLOAT16
                    model.graph.node.extend(new_node)
                    # change current node's output name
                    node.output[i] = input_name
                    break
    return model
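

# Illustrative example (not part of the packaged file): convert a model on disk
# while keeping float32 graph inputs/outputs ("model.onnx" is a placeholder path):
#
#   model_fp16 = convert_float_to_float16("model.onnx", keep_io_types=True)
#   onnx.save(model_fp16, "model_fp16.onnx")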


def float_to_float16_max_diff(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0):
    """Measure the maximum absolute difference after converting a float tensor to float16."""
    if not isinstance(tensor, TensorProto):
        raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")
    if tensor.data_type != TensorProto.FLOAT:
        raise ValueError("Expected tensor data type is float.")

    float32_data = None
    if tensor.float_data:
        float32_data = np.array(tensor.float_data)

    if tensor.raw_data:
        float32_data = np.frombuffer(tensor.raw_data, dtype="float32")

    if float32_data is None:
        raise RuntimeError("external data not loaded!")

    float16_data = convert_np_to_float16(float32_data, min_positive_val, max_finite_val)
    return np.amax(np.abs(float32_data - np.float32(float16_data)))
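
A minimal usage sketch for the helpers above (assuming a local "model.onnx" whose tensors are stored inline, i.e. no external data), checking per-initializer float16 error before converting:

    import onnx
    from onnxruntime.transformers.float16 import convert_float_to_float16, float_to_float16_max_diff

    model = onnx.load("model.onnx")
    for initializer in model.graph.initializer:
        if initializer.data_type == onnx.TensorProto.FLOAT:
            # large values here suggest adding the consuming nodes to node_block_list
            print(initializer.name, float_to_float16_max_diff(initializer))

    model_fp16 = convert_float_to_float16(model, keep_io_types=True)
    onnx.save(model_fp16, "model_fp16.onnx")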