onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: fbs

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()

class TensorTypeAndShape(object):
    """FlatBuffers accessor for a tensor's element type and (optional) shape."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Read the root table offset stored at `offset`, then position a
        # new accessor object at that table.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = TensorTypeAndShape()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsTensorTypeAndShape(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    @classmethod
    def TensorTypeAndShapeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is "ORTM", the ORT format file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # TensorTypeAndShape
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # TensorTypeAndShape
    def ElemType(self):
        # Scalar field at vtable offset 4 (int32); FlatBuffers default is 0
        # when the field was not written.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0

    # TensorTypeAndShape
    def Shape(self):
        # Nested table field at vtable offset 6; returns None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.Shape import Shape
            obj = Shape()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

# Module-level builder helpers: serialize a TensorTypeAndShape table with a
# flatbuffers.Builder. The short-named variants are the modern aliases.
def TensorTypeAndShapeStart(builder):
    builder.StartObject(2)

def Start(builder):
    TensorTypeAndShapeStart(builder)

def TensorTypeAndShapeAddElemType(builder, elemType):
    builder.PrependInt32Slot(0, elemType, 0)

def AddElemType(builder, elemType):
    TensorTypeAndShapeAddElemType(builder, elemType)

def TensorTypeAndShapeAddShape(builder, shape):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)

def AddShape(builder, shape):
    TensorTypeAndShapeAddShape(builder, shape)

def TensorTypeAndShapeEnd(builder):
    return builder.EndObject()

def End(builder):
    return TensorTypeAndShapeEnd(builder)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: fbs

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()

class TypeInfo(object):
    """FlatBuffers accessor for a value's type: denotation plus a union
    (ValueType discriminant, Value payload)."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Read the root table offset stored at `offset`, then position a
        # new accessor object at that table.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = TypeInfo()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsTypeInfo(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    @classmethod
    def TypeInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is "ORTM", the ORT format file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # TypeInfo
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # TypeInfo
    def Denotation(self):
        # String field at vtable offset 4; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # TypeInfo
    def ValueType(self):
        # Union discriminant (uint8) at vtable offset 6; 0 means NONE.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
        return 0

    # TypeInfo
    def Value(self):
        # Union payload at vtable offset 8, returned as an untyped Table;
        # the caller selects the concrete type using ValueType().
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            from flatbuffers.table import Table
            obj = Table(bytearray(), 0)
            self._tab.Union(obj, o)
            return obj
        return None

# Module-level builder helpers: serialize a TypeInfo table with a
# flatbuffers.Builder. The short-named variants are the modern aliases.
def TypeInfoStart(builder):
    builder.StartObject(3)

def Start(builder):
    TypeInfoStart(builder)

def TypeInfoAddDenotation(builder, denotation):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(denotation), 0)

def AddDenotation(builder, denotation):
    TypeInfoAddDenotation(builder, denotation)

def TypeInfoAddValueType(builder, valueType):
    builder.PrependUint8Slot(1, valueType, 0)

def AddValueType(builder, valueType):
    TypeInfoAddValueType(builder, valueType)

def TypeInfoAddValue(builder, value):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(value), 0)

def AddValue(builder, value):
    TypeInfoAddValue(builder, value)

def TypeInfoEnd(builder):
    return builder.EndObject()

def End(builder):
    return TypeInfoEnd(builder)
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: fbs

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()

class ValueInfo(object):
    """FlatBuffers accessor for a named value: name, doc string and TypeInfo."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Read the root table offset stored at `offset`, then position a
        # new accessor object at that table.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = ValueInfo()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsValueInfo(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    @classmethod
    def ValueInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is "ORTM", the ORT format file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # ValueInfo
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # ValueInfo
    def Name(self):
        # String field at vtable offset 4; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # ValueInfo
    def DocString(self):
        # String field at vtable offset 6; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # ValueInfo
    def Type(self):
        # Nested TypeInfo table at vtable offset 8; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.TypeInfo import TypeInfo
            obj = TypeInfo()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

# Module-level builder helpers: serialize a ValueInfo table with a
# flatbuffers.Builder. The short-named variants are the modern aliases.
def ValueInfoStart(builder):
    builder.StartObject(3)

def Start(builder):
    ValueInfoStart(builder)

def ValueInfoAddName(builder, name):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)

def AddName(builder, name):
    ValueInfoAddName(builder, name)

def ValueInfoAddDocString(builder, docString):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)

def AddDocString(builder, docString):
    ValueInfoAddDocString(builder, docString)

def ValueInfoAddType(builder, type):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(type), 0)

def AddType(builder, type):
    ValueInfoAddType(builder, type)

def ValueInfoEnd(builder):
    return builder.EndObject()

def End(builder):
    return ValueInfoEnd(builder)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
import ort_flatbuffers_py.fbs as fbs
|
|
5
|
+
|
|
6
|
+
from .operator_type_usage_processors import OperatorTypeUsageManager
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class OrtFormatModelProcessor:
|
|
10
|
+
"Class to process an ORT format model and determine required operators and types."
|
|
11
|
+
|
|
12
|
+
def __init__(self, model_path: str, required_ops: dict, processors: OperatorTypeUsageManager):
|
|
13
|
+
"""
|
|
14
|
+
Initialize ORT format model processor
|
|
15
|
+
:param model_path: Path to model to load
|
|
16
|
+
:param required_ops: Dictionary required operator information will be added to.
|
|
17
|
+
:param processors: Operator type usage processors which will be called for each matching Node.
|
|
18
|
+
"""
|
|
19
|
+
self._required_ops = required_ops # dictionary of {domain: {opset:[operators]}}
|
|
20
|
+
self._file = open(model_path, "rb").read() # noqa: SIM115
|
|
21
|
+
self._buffer = bytearray(self._file)
|
|
22
|
+
if not fbs.InferenceSession.InferenceSession.InferenceSessionBufferHasIdentifier(self._buffer, 0):
|
|
23
|
+
raise RuntimeError(f"File does not appear to be a valid ORT format model: '{model_path}'")
|
|
24
|
+
self._model = fbs.InferenceSession.InferenceSession.GetRootAsInferenceSession(self._buffer, 0).Model()
|
|
25
|
+
self._op_type_processors = processors
|
|
26
|
+
|
|
27
|
+
@staticmethod
|
|
28
|
+
def _setup_type_info(graph: fbs.Graph, outer_scope_value_typeinfo={}): # noqa: B006
|
|
29
|
+
"""
|
|
30
|
+
Setup the node args for this level of Graph.
|
|
31
|
+
We copy the current list which represents the outer scope values, and add the local node args to that
|
|
32
|
+
to create the valid list of values for the current Graph.
|
|
33
|
+
:param graph: Graph to create NodeArg list for
|
|
34
|
+
:param outer_scope_value_typeinfo: TypeInfo for outer scope values. Empty for the top-level graph in a model.
|
|
35
|
+
:return: Dictionary of NodeArg name to TypeInfo
|
|
36
|
+
"""
|
|
37
|
+
value_name_to_typeinfo = outer_scope_value_typeinfo.copy()
|
|
38
|
+
for j in range(graph.NodeArgsLength()):
|
|
39
|
+
n = graph.NodeArgs(j)
|
|
40
|
+
value_name_to_typeinfo[n.Name()] = n.Type() # TypeInfo for this NodeArg's name
|
|
41
|
+
|
|
42
|
+
return value_name_to_typeinfo
|
|
43
|
+
|
|
44
|
+
def _add_required_op(self, domain: str, opset: int, op_type: str):
|
|
45
|
+
if domain not in self._required_ops:
|
|
46
|
+
self._required_ops[domain] = {opset: {op_type}}
|
|
47
|
+
elif opset not in self._required_ops[domain]:
|
|
48
|
+
self._required_ops[domain][opset] = {op_type}
|
|
49
|
+
else:
|
|
50
|
+
self._required_ops[domain][opset].add(op_type)
|
|
51
|
+
|
|
52
|
+
def _process_graph(self, graph: fbs.Graph, outer_scope_value_typeinfo: dict):
    """
    Process one level of the Graph, descending into any subgraphs when they are found.

    :param graph: Graph at the current level.
    :param outer_scope_value_typeinfo: Outer scope NodeArg dictionary from ancestor graphs.
    """
    # Combine TypeInfo for values at this graph level with the outer scope TypeInfo.
    value_name_to_typeinfo = OrtFormatModelProcessor._setup_type_info(graph, outer_scope_value_typeinfo)

    for node_idx in range(graph.NodesLength()):
        node = graph.Nodes(node_idx)

        op_type = node.OpType().decode()
        # An empty domain string is shorthand for the default ONNX domain.
        domain = node.Domain().decode() or "ai.onnx"

        self._add_required_op(domain, node.SinceVersion(), op_type)

        if self._op_type_processors:
            self._op_type_processors.process_node(node, value_name_to_typeinfo)

        # Walk all attributes, recursing into any subgraphs they contain.
        for attr_idx in range(node.AttributesLength()):
            attribute = node.Attributes(attr_idx)
            attribute_type = attribute.Type()
            if attribute_type == fbs.AttributeType.AttributeType.GRAPH:
                self._process_graph(attribute.G(), value_name_to_typeinfo)
            elif attribute_type == fbs.AttributeType.AttributeType.GRAPHS:
                # The ONNX spec doesn't currently define any operators that have multiple graphs
                # in an attribute, so entering this 'elif' isn't currently possible.
                for subgraph_idx in range(attribute.GraphsLength()):
                    self._process_graph(attribute.Graphs(subgraph_idx), value_name_to_typeinfo)
|
|
82
|
+
|
|
83
|
+
def process(self):
    """Process the model's main graph; there are no outer scope values at the top level."""
    main_graph = self._model.Graph()
    self._process_graph(main_graph, {})
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
import ort_flatbuffers_py.fbs as fbs
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class FbsTypeInfo:
    "Class to provide conversion between ORT flatbuffers schema values and C++ types"

    # Map from fbs.TensorDataType enum value to the C++ type string used in generated config entries.
    tensordatatype_to_string = {  # noqa: RUF012
        fbs.TensorDataType.TensorDataType.FLOAT: "float",
        fbs.TensorDataType.TensorDataType.UINT8: "uint8_t",
        fbs.TensorDataType.TensorDataType.INT8: "int8_t",
        fbs.TensorDataType.TensorDataType.UINT16: "uint16_t",
        fbs.TensorDataType.TensorDataType.INT16: "int16_t",
        fbs.TensorDataType.TensorDataType.INT32: "int32_t",
        fbs.TensorDataType.TensorDataType.INT64: "int64_t",
        fbs.TensorDataType.TensorDataType.STRING: "std::string",
        fbs.TensorDataType.TensorDataType.BOOL: "bool",
        fbs.TensorDataType.TensorDataType.FLOAT16: "MLFloat16",
        fbs.TensorDataType.TensorDataType.DOUBLE: "double",
        fbs.TensorDataType.TensorDataType.UINT32: "uint32_t",
        fbs.TensorDataType.TensorDataType.UINT64: "uint64_t",
        # fbs.TensorDataType.TensorDataType.COMPLEX64: 'complex64 is not supported',
        # fbs.TensorDataType.TensorDataType.COMPLEX128: 'complex128 is not supported',
        fbs.TensorDataType.TensorDataType.BFLOAT16: "BFloat16",
        fbs.TensorDataType.TensorDataType.FLOAT8E4M3FN: "Float8E4M3FN",
        fbs.TensorDataType.TensorDataType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ",
        fbs.TensorDataType.TensorDataType.FLOAT8E5M2: "Float8E5M2",
        fbs.TensorDataType.TensorDataType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ",
    }

    @staticmethod
    def typeinfo_to_str(type: fbs.TypeInfo):
        """
        Convert a flatbuffers TypeInfo value to a string representing the C++ type.

        :param type: fbs.TypeInfo instance. May be a tensor, map or sequence type.
        :return: String with the C++ type.
        :raises ValueError: If the TypeInfo value type is unknown or missing.
        """
        value_type = type.ValueType()
        value = type.Value()
        type_str = "unknown"

        if value_type == fbs.TypeInfoValue.TypeInfoValue.tensor_type:
            tensor_type_and_shape = fbs.TensorTypeAndShape.TensorTypeAndShape()
            tensor_type_and_shape.Init(value.Bytes, value.Pos)
            elem_type = tensor_type_and_shape.ElemType()
            type_str = FbsTypeInfo.tensordatatype_to_string[elem_type]

        elif value_type == fbs.TypeInfoValue.TypeInfoValue.map_type:
            map_type = fbs.MapType.MapType()
            # Bug fix: flatbuffers-generated Python classes expose `Init` (capitalized).
            # The previous lowercase `map_type.init(...)` would raise AttributeError for
            # any model containing a map type.
            map_type.Init(value.Bytes, value.Pos)
            key_type = map_type.KeyType()  # TensorDataType
            key_type_str = FbsTypeInfo.tensordatatype_to_string[key_type]
            value_type = map_type.ValueType()  # TypeInfo
            value_type_str = FbsTypeInfo.typeinfo_to_str(value_type)
            type_str = f"std::map<{key_type_str},{value_type_str}>"

        elif value_type == fbs.TypeInfoValue.TypeInfoValue.sequence_type:
            sequence_type = fbs.SequenceType.SequenceType()
            sequence_type.Init(value.Bytes, value.Pos)
            elem_type = sequence_type.ElemType()  # TypeInfo
            elem_type_str = FbsTypeInfo.typeinfo_to_str(elem_type)
            # TODO: Decide if we need to wrap the type in a std::vector. Issue is that the element type is internal
            # to the onnxruntime::Tensor class so we're really returning the type inside the Tensor not vector<Tensor>.
            # For now, return the element type (which will be the Tensor element type, or a map<A,B>) as
            # an operator input or output will either be a sequence or a not, so we don't need to disambiguate
            # between the two (i.e. we know if the returned value refers to the contents of a sequence, and can
            # handle whether it's the element type of a Tensor in the sequence, or the map type in a sequence of maps
            # due to this).
            type_str = elem_type_str
        else:
            raise ValueError(f"Unknown or missing value type of {value_type}")

        return type_str
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_typeinfo(name: str, value_name_to_typeinfo: dict) -> fbs.TypeInfo:
    "Lookup a name in a dictionary mapping value name to TypeInfo."
    # Guard clause: a missing entry indicates an inconsistency in the model, so fail loudly.
    if name in value_name_to_typeinfo:
        return value_name_to_typeinfo[name]  # TypeInfo object

    raise RuntimeError("Missing TypeInfo entry for " + name)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def value_name_to_typestr(name: str, value_name_to_typeinfo: dict):
    "Lookup TypeInfo for value name and convert to a string representing the C++ type."
    typeinfo = get_typeinfo(name, value_name_to_typeinfo)
    return FbsTypeInfo.typeinfo_to_str(typeinfo)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
import pathlib
|
|
5
|
+
import typing
|
|
6
|
+
|
|
7
|
+
from ..logger import get_logger
|
|
8
|
+
from .operator_type_usage_processors import OperatorTypeUsageManager
|
|
9
|
+
from .ort_model_processor import OrtFormatModelProcessor
|
|
10
|
+
|
|
11
|
+
log = get_logger("ort_format_model.utils")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _extract_ops_and_types_from_ort_models(model_files: typing.Iterable[pathlib.Path], enable_type_reduction: bool):
    """
    Gather required operators (and optionally type usage info) from a set of ORT format models.

    :param model_files: Paths of the ORT format models to process.
    :param enable_type_reduction: If True, also track per-operator type usage.
    :return: Tuple of (required_ops dict, OperatorTypeUsageManager or None).
    """
    required_ops = {}
    op_type_usage_manager = None
    if enable_type_reduction:
        op_type_usage_manager = OperatorTypeUsageManager()

    for model_file in model_files:
        if not model_file.is_file():
            raise ValueError(f"Path is not a file: '{model_file}'")
        # process() updates required_ops and the usage manager in place.
        OrtFormatModelProcessor(str(model_file), required_ops, op_type_usage_manager).process()

    return required_ops, op_type_usage_manager
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def create_config_from_models(
    model_files: typing.Iterable[pathlib.Path], output_file: pathlib.Path, enable_type_reduction: bool
):
    """
    Create a configuration file with required operators and optionally required types.

    :param model_files: Model files to use to generate the configuration file.
    :param output_file: File to write configuration to.
    :param enable_type_reduction: Include required type information for individual operators in the configuration.
    """
    required_ops, op_type_processors = _extract_ops_and_types_from_ort_models(model_files, enable_type_reduction)

    output_file.parent.mkdir(parents=True, exist_ok=True)

    with output_file.open("w") as out:
        # Header listing the models the configuration was generated from.
        out.write("# Generated from model/s:\n")
        for model_file in sorted(model_files):
            out.write(f"# - {model_file}\n")

        # One line per domain/opset: "<domain>;<opset>;<op[,op...]>"
        for domain in sorted(required_ops):
            for opset in sorted(required_ops[domain]):
                ops = required_ops[domain][opset]
                if not ops:
                    continue
                out.write(f"{domain};{opset};")
                if enable_type_reduction:
                    # type string is empty if op hasn't been seen
                    entries = [f"{op}{op_type_processors.get_config_entry(domain, op) or ''}" for op in sorted(ops)]
                else:
                    entries = sorted(ops)

                out.write(",".join(entries) + "\n")

    log.info("Created config in %s", output_file)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Support for registering ONNX Runtime's built-in contrib ops with
|
|
6
|
+
PyTorch-ONNX exporter (torch.onnx.export).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import typing
|
|
10
|
+
|
|
11
|
+
# Fail at import time with an actionable message when PyTorch is absent, since
# nothing in this module is usable without it.
try:
    # TODO(justinchuby): Create a function to alert users when torch is not installed
    import torch
except ModuleNotFoundError:
    raise ModuleNotFoundError(  # noqa: B904
        "This module is only useful in combination with PyTorch. To install PyTorch see https://pytorch.org/."
    )

from torch.onnx import symbolic_helper

# All contrib ops are registered against this single symbolic opset version.
_OPSET_VERSION = 1
# Fully qualified names ("namespace::op") added by _reg(), consumed by unregister().
_registered_ops: typing.AbstractSet[str] = set()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _reg(symbolic_fn: typing.Callable, namespace: str = ""):
    """Register symbolic_fn under '<namespace>::<function name>' and record it for unregister()."""
    qualified_name = f"{namespace}::{symbolic_fn.__name__}"
    torch.onnx.register_custom_op_symbolic(qualified_name, symbolic_fn, _OPSET_VERSION)
    _registered_ops.add(qualified_name)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def register():
    """Register ONNX Runtime's built-in contrib ops.

    Should be run before torch.onnx.export().
    """

    def grid_sampler(g, input, grid, mode, padding_mode, align_corners):
        # Integer constants map onto the string attributes the contrib op expects:
        # mode
        #   'bilinear'      : onnx::Constant[value={0}]
        #   'nearest'       : onnx::Constant[value={1}]
        #   'bicubic'       : onnx::Constant[value={2}]
        # padding_mode
        #   'zeros'         : onnx::Constant[value={0}]
        #   'border'        : onnx::Constant[value={1}]
        #   'reflection'    : onnx::Constant[value={2}]
        mode = symbolic_helper._maybe_get_const(mode, "i")
        padding_mode = symbolic_helper._maybe_get_const(padding_mode, "i")
        mode_str = ["bilinear", "nearest", "bicubic"][mode]
        padding_mode_str = ["zeros", "border", "reflection"][padding_mode]
        align_corners = int(symbolic_helper._maybe_get_const(align_corners, "b"))

        # From opset v13 onward, the output shape can be specified with
        # (N, C, H, W) (N, H_out, W_out, 2) => (N, C, H_out, W_out)
        # input_shape = input.type().sizes()
        # gird_shape = grid.type().sizes()
        # output_shape = input_shape[:2] + gird_shape[1:3]
        # g.op(...).setType(input.type().with_sizes(output_shape))

        return g.op(
            "com.microsoft::GridSample",
            input,
            grid,
            mode_s=mode_str,
            padding_mode_s=padding_mode_str,
            align_corners_i=align_corners,
        )

    _reg(grid_sampler)

    def inverse(g, self):
        # setType propagates the input's type/shape info so downstream shape inference works.
        return g.op("com.microsoft::Inverse", self).setType(self.type())

    _reg(inverse)

    @torch.onnx.symbolic_helper.parse_args("v", "s")
    def gelu(g, self: torch._C.Value, approximate: str = "none"):
        # Use microsoft::Gelu for performance if possible. It only supports approximate == "none"
        if approximate == "none":
            return g.op("com.microsoft::Gelu", self).setType(self.type())
        return torch.onnx.symbolic_opset9.gelu(g, self, approximate)

    _reg(gelu)

    def triu(g, self, diagonal):
        # Trilu with upper_i=1 keeps the upper triangular part (torch.triu equivalent).
        return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1).setType(self.type())

    _reg(triu)

    def tril(g, self, diagonal):
        # Trilu with upper_i=0 keeps the lower triangular part (torch.tril equivalent).
        return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0).setType(self.type())

    _reg(tril)

    @torch.onnx.symbolic_helper.parse_args("v")
    def DynamicTimeWarping(g, self):  # noqa: N802
        return g.op("com.microsoft::DynamicTimeWarping", self)

    _reg(DynamicTimeWarping, namespace="onnxruntime")

    def UnfoldTensor(g, self, dim, size, step):  # noqa: N802
        dim = int(symbolic_helper._maybe_get_const(dim, "i"))
        size = int(symbolic_helper._maybe_get_const(size, "i"))
        step = int(symbolic_helper._maybe_get_const(step, "i"))
        # Only the final (window) dimension size is statically known; the rest are left dynamic.
        return g.op(
            "com.microsoft::UnfoldTensor",
            self,
            dim_i=dim,
            size_i=size,
            step_i=step,
        ).setType(self.type().with_sizes([None, None, None, None, size]))

    _reg(UnfoldTensor, namespace="onnxruntime")
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def unregister():
    """Unregister ONNX Runtime's built-in contrib ops."""
    for name in _registered_ops:
        try:
            # Preferred path: available in PyTorch >= 1.12.
            torch.onnx.unregister_custom_op_symbolic(name, _OPSET_VERSION)
        except AttributeError:
            # The symbolic_registry module was removed in PyTorch 1.13.
            # We are importing it here for backwards compatibility
            # because unregister_custom_op_symbolic is not available before PyTorch 1.12
            from torch.onnx import symbolic_registry  # noqa: PLC0415

            namespace, kind = name.split("::")
            # Remove the op from every stable opset at or above our registration version,
            # mirroring how register_custom_op_symbolic propagates a registration.
            for version in symbolic_helper._onnx_stable_opsets:
                if version >= _OPSET_VERSION and symbolic_registry.is_registered_op(kind, namespace, version):
                    del symbolic_registry._registry[(namespace, version)][kind]