onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,653 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
|
|
8
|
+
import ort_flatbuffers_py.fbs as fbs
|
|
9
|
+
|
|
10
|
+
from .types import FbsTypeInfo, value_name_to_typestr
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _create_op_key(domain: str, optype: str):
|
|
14
|
+
return f"{domain}:{optype}"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _ort_constant_for_domain(domain: str):
|
|
18
|
+
"""
|
|
19
|
+
Map a string domain value to the internal ONNX Runtime constant for that domain.
|
|
20
|
+
:param domain: Domain string to map.
|
|
21
|
+
:return: Internal ONNX Runtime constant
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
# constants are defined in <ORT root>/include/onnxruntime/core/graph/constants.h
|
|
25
|
+
# This list is limited to just the domains we have processors for
|
|
26
|
+
domain_to_constant_map = {"ai.onnx": "kOnnxDomain", "ai.onnx.ml": "kMLDomain", "com.microsoft": "kMSDomain"}
|
|
27
|
+
|
|
28
|
+
if domain not in domain_to_constant_map:
|
|
29
|
+
raise ValueError(f"Domain {domain} not found in map to ONNX Runtime constant. Please update map.")
|
|
30
|
+
|
|
31
|
+
return domain_to_constant_map[domain]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _reg_type_to_cpp_type(reg_type: str):
|
|
35
|
+
if reg_type == "string":
|
|
36
|
+
return "std::string"
|
|
37
|
+
return reg_type
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _split_reg_types(reg_types_str: str):
|
|
41
|
+
"""
|
|
42
|
+
Split on underscores but append "_t" to the previous element.
|
|
43
|
+
"""
|
|
44
|
+
tokens = reg_types_str.split("_")
|
|
45
|
+
reg_types = []
|
|
46
|
+
for token in tokens:
|
|
47
|
+
if token == "t" and len(reg_types) > 0:
|
|
48
|
+
reg_types[-1] += "_t"
|
|
49
|
+
else:
|
|
50
|
+
reg_types += [token]
|
|
51
|
+
return reg_types
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class TypeUsageProcessor(ABC):
    """
    Base class for per-operator processors that work out which type or types an operator requires.

    Concrete subclasses implement the operator specific tracking logic.
    """

    def __init__(self, domain: str, optype: str):
        self.domain = domain
        self.optype = optype
        # same "<domain>:<optype>" key format that _create_op_key produces
        self.name = f"{domain}:{optype}"

    @abstractmethod
    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """Record the type information used by the given node."""

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """
        Given the string from a kernel registration, determine if the registration is required or not.
        :param type_in_registration: Type string from kernel registration
        :param globally_allowed_types: Optional set of globally allowed types. If provided, these types take precedence
                                       in determining the required types.
        :return: True is required. False if not.
        """
        # Only operators with typed registrations implement this, so by default calling it is an error.
        raise RuntimeError(f"Did not expect processor for {self.name} to have typed registrations.")

    def get_cpp_entry(self):
        """
        Get the C++ code that specifies this operator's required types.
        :return: List with any applicable C++ code for this operator's required types. One line per entry.
        """
        # Most operators have nothing to contribute here.
        return []

    @abstractmethod
    def to_config_entry(self):
        """
        Generate a configuration file entry in JSON format with the required types for the operator.
        :return: JSON string with required type information.
        """

    @abstractmethod
    def from_config_entry(self, entry: str):
        """
        Re-create the types required from a configuration file entry created with to_config_entry.
        NOTE: Any existing type information should be cleared prior to re-creating from a config file entry.
        :param entry: Configuration file entry
        """
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class DefaultTypeUsageProcessor(TypeUsageProcessor):
|
|
104
|
+
"""
|
|
105
|
+
Operator processor which tracks the types used for selected input/s and/or output/s.
|
|
106
|
+
"""
|
|
107
|
+
|
|
108
|
+
def __init__(
|
|
109
|
+
self,
|
|
110
|
+
domain: str,
|
|
111
|
+
optype: str,
|
|
112
|
+
inputs: [int] = [0], # noqa: B006
|
|
113
|
+
outputs: [int] = [], # noqa: B006
|
|
114
|
+
required_input_types: dict[int, set[str]] = {}, # noqa: B006
|
|
115
|
+
required_output_types: dict[int, set[str]] = {}, # noqa: B006
|
|
116
|
+
):
|
|
117
|
+
"""
|
|
118
|
+
Create DefaultTypeUsageProcessor. Types for one or more inputs and/or outputs can be tracked by the processor.
|
|
119
|
+
The default is to track the types required for input 0, as this is the most common use case in ONNX.
|
|
120
|
+
|
|
121
|
+
Required input and output types may be specified. These are only applicable to is_typed_registration_needed().
|
|
122
|
+
If a registration type matches a required type, the typed registration is needed.
|
|
123
|
+
There is a separate mechanism for specifying required types from C++ for kernels with untyped registration.
|
|
124
|
+
|
|
125
|
+
:param domain: Operator domain.
|
|
126
|
+
:param optype: Operator name.
|
|
127
|
+
:param inputs: Inputs to track. Zero based index. May be empty.
|
|
128
|
+
:param outputs: Outputs to track. Zero based index. May be empty.
|
|
129
|
+
:param required_input_types: Required input types. May be empty.
|
|
130
|
+
:param required_output_types: Required output types. May be empty.
|
|
131
|
+
"""
|
|
132
|
+
super().__init__(domain, optype)
|
|
133
|
+
self._input_types = {}
|
|
134
|
+
self._output_types = {}
|
|
135
|
+
|
|
136
|
+
for i in inputs:
|
|
137
|
+
self._input_types[i] = set()
|
|
138
|
+
|
|
139
|
+
for o in outputs:
|
|
140
|
+
self._output_types[o] = set()
|
|
141
|
+
|
|
142
|
+
if not inputs and not outputs:
|
|
143
|
+
raise ValueError("At least one input or output must be tracked")
|
|
144
|
+
|
|
145
|
+
self._required_input_types = required_input_types
|
|
146
|
+
self._required_output_types = required_output_types
|
|
147
|
+
|
|
148
|
+
def _is_type_enabled(self, reg_type, index, required_types, allowed_type_set):
|
|
149
|
+
cpp_type = _reg_type_to_cpp_type(reg_type)
|
|
150
|
+
return cpp_type in required_types.get(index, set()) or cpp_type in allowed_type_set
|
|
151
|
+
|
|
152
|
+
def is_input_type_enabled(self, reg_type, index, allowed_type_set=None):
|
|
153
|
+
"""Whether input type is enabled based on required and allowed types."""
|
|
154
|
+
if allowed_type_set is None:
|
|
155
|
+
allowed_type_set = self._input_types[index]
|
|
156
|
+
return self._is_type_enabled(reg_type, index, self._required_input_types, allowed_type_set)
|
|
157
|
+
|
|
158
|
+
def is_output_type_enabled(self, reg_type, index, allowed_type_set=None):
|
|
159
|
+
"""Whether output type is enabled based on required and allowed types."""
|
|
160
|
+
if allowed_type_set is None:
|
|
161
|
+
allowed_type_set = self._output_types[index]
|
|
162
|
+
return self._is_type_enabled(reg_type, index, self._required_output_types, allowed_type_set)
|
|
163
|
+
|
|
164
|
+
def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
|
|
165
|
+
for i in self._input_types:
|
|
166
|
+
if i >= node.InputsLength():
|
|
167
|
+
# Some operators have fewer inputs in earlier versions where data that was as an attribute
|
|
168
|
+
# become an input in later versions to allow it to be dynamically provided. Allow for that.
|
|
169
|
+
# e.g. Slice-1 had attributes for the indices, and Slice-10 moved those to be inputs
|
|
170
|
+
# raise RuntimeError('Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.'
|
|
171
|
+
# .format(node.OutputsLength(), self.name, o))
|
|
172
|
+
pass
|
|
173
|
+
else:
|
|
174
|
+
type_str = value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo)
|
|
175
|
+
self._input_types[i].add(type_str)
|
|
176
|
+
|
|
177
|
+
for o in self._output_types:
|
|
178
|
+
# Don't know of any ops where the number of outputs changed across versions, so require a valid length
|
|
179
|
+
if o >= node.OutputsLength():
|
|
180
|
+
raise RuntimeError(
|
|
181
|
+
f"Node has {node.OutputsLength()} outputs. Tracker for {self.name} incorrectly configured as it requires {o}."
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
type_str = value_name_to_typestr(node.Outputs(o), value_name_to_typeinfo)
|
|
185
|
+
self._output_types[o].add(type_str)
|
|
186
|
+
|
|
187
|
+
def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
|
|
188
|
+
if 0 not in self._input_types:
|
|
189
|
+
# currently all standard typed registrations are for input 0.
|
|
190
|
+
# custom registrations can be handled by operator specific processors (e.g. OneHotProcessor below).
|
|
191
|
+
raise RuntimeError(f"Expected typed registration to use type from input 0. Node:{self.name}")
|
|
192
|
+
|
|
193
|
+
return self.is_input_type_enabled(type_in_registration, 0, globally_allowed_types)
|
|
194
|
+
|
|
195
|
+
def get_cpp_entry(self):
|
|
196
|
+
entries = []
|
|
197
|
+
domain = _ort_constant_for_domain(self.domain)
|
|
198
|
+
for i in sorted(self._input_types.keys()):
|
|
199
|
+
if self._input_types[i]:
|
|
200
|
+
entries.append(
|
|
201
|
+
"ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Input, {}, {});".format(
|
|
202
|
+
domain, self.optype, i, ", ".join(sorted(self._input_types[i]))
|
|
203
|
+
)
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
for o in sorted(self._output_types.keys()):
|
|
207
|
+
if self._output_types[o]:
|
|
208
|
+
entries.append(
|
|
209
|
+
"ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Output, {}, {});".format(
|
|
210
|
+
domain, self.optype, o, ", ".join(sorted(self._output_types[o]))
|
|
211
|
+
)
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
return entries
|
|
215
|
+
|
|
216
|
+
def to_config_entry(self):
|
|
217
|
+
# convert the sets of types to lists so they can easily written out using the json model
|
|
218
|
+
aggregate_info = {"inputs": {}, "outputs": {}}
|
|
219
|
+
|
|
220
|
+
# filter out empty entries and sort the types
|
|
221
|
+
for i in sorted(self._input_types.keys()):
|
|
222
|
+
if self._input_types[i]:
|
|
223
|
+
aggregate_info["inputs"][i] = sorted(self._input_types[i])
|
|
224
|
+
|
|
225
|
+
for o in sorted(self._output_types.keys()):
|
|
226
|
+
if self._output_types[o]:
|
|
227
|
+
aggregate_info["outputs"][o] = sorted(self._output_types[o])
|
|
228
|
+
|
|
229
|
+
# remove any empty keys
|
|
230
|
+
if not aggregate_info["inputs"]:
|
|
231
|
+
aggregate_info.pop("inputs")
|
|
232
|
+
if not aggregate_info["outputs"]:
|
|
233
|
+
aggregate_info.pop("outputs")
|
|
234
|
+
|
|
235
|
+
entry = json.dumps(aggregate_info) if aggregate_info else None
|
|
236
|
+
return entry
|
|
237
|
+
|
|
238
|
+
def from_config_entry(self, entry: str):
    """Reset the tracked types from a JSON config entry produced by to_config_entry."""
    self._input_types.clear()
    self._output_types.clear()

    aggregate_info = json.loads(entry)
    # JSON object keys are strings, so convert back to int indices
    for idx_str, values in aggregate_info.get("inputs", {}).items():
        self._input_types[int(idx_str)] = set(values)

    for idx_str, values in aggregate_info.get("outputs", {}).items():
        self._output_types[int(idx_str)] = set(values)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
class Input1TypedRegistrationProcessor(DefaultTypeUsageProcessor):
    """
    Processor for operators where the second input type is used in a typed kernel registration.
    """

    def __init__(self, domain: str, optype: str):
        # the typed registration only involves input 1, so that is the only value we track
        super().__init__(domain, optype, inputs=[1], outputs=[])

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """The typed registration is keyed on the type of input 1."""
        return self.is_input_type_enabled(type_in_registration, 1, globally_allowed_types)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class Output0TypedRegistrationProcessor(DefaultTypeUsageProcessor):
    """
    Processor for operators where the first output type is used in a typed kernel registration.
    """

    def __init__(self, domain: str, optype: str):
        # the typed registration only involves output 0, so that is the only value we track
        super().__init__(domain, optype, inputs=[], outputs=[0])

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """The typed registration is keyed on the type of output 0."""
        return self.is_output_type_enabled(type_in_registration, 0, globally_allowed_types)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class OneHotProcessor(TypeUsageProcessor):
    """
    Processor for the OneHot operator, which requires custom logic as the type registration key is a concatenation of
    the three types involved instead of a single type name.
    """

    def __init__(self):
        super().__init__("ai.onnx", "OneHot")
        # (input type, output type, depth type) tuples observed while processing models
        self._triples = set()

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """Record the type triple this OneHot node uses, in kernel-registration order."""
        input_type, depth_type, values_type = (
            value_name_to_typestr(node.Inputs(idx), value_name_to_typeinfo) for idx in range(3)
        )
        # types in kernel registration are ordered this way: input (T1), output (T3), depth (T2)
        self._triples.add((input_type, values_type, depth_type))

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """Split the concatenated registration string into its 3 types before checking."""
        reg_types = tuple(_reg_type_to_cpp_type(t) for t in _split_reg_types(type_in_registration))
        if globally_allowed_types is None:
            return reg_types in self._triples

        return all(t in globally_allowed_types for t in reg_types)

    def to_config_entry(self):
        """Serialize the recorded triples to JSON. Returns None when nothing was recorded."""
        if not self._triples:
            return None

        return json.dumps({"custom": sorted(self._triples)})

    def from_config_entry(self, entry: str):
        """Replace the recorded triples with those from a JSON config entry."""
        self._triples.clear()
        aggregate_info = json.loads(entry)
        if "custom" in aggregate_info:
            self._triples = {tuple(triple) for triple in aggregate_info["custom"]}
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def _create_operator_type_usage_processors():
    """
    Create a set of processors that determine the required types for all enabled operators.
    :return: Dictionary of operator key to processor. Key is 'domain:operator (e.g. ai.onnx:Cast)'.
    """
    operator_processors = {}

    def add(processor):
        # a duplicate registration indicates a copy/paste error in the op lists below
        if processor.name in operator_processors:
            raise RuntimeError("Duplicate processor for " + processor.name)

        operator_processors[processor.name] = processor

    # Starting with ops from:
    # - Priority 1P models
    # - Mobilenet + SSD Mobilenet + MobileBert
    # - some known large kernels
    #
    # Ops we are ignoring currently so as not to produce meaningless/unused output:
    # - Implementation is type agnostic:
    #   ai.onnx: If, Loop, Reshape, Scan, Shape, Squeeze, Tile, Unsqueeze
    #   com.microsoft: DynamicQuantizeMatMul, MatMulIntegerToFloat
    # - Only one type supported in the ORT implementation:
    #   ai.onnx: NonMaxSuppression
    #   com.microsoft: FusedConv, FusedGemm, FusedMatMul
    # - Implementation does not have any significant type specific code:
    #   ai.onnx: Concat, Flatten, Not, Reshape, Shape, Squeeze, Unsqueeze
    #
    default_processor_onnx_ops = [
        "Abs",
        "ArgMax",
        "ArgMin",
        "AveragePool",
        "BatchNormalization",
        "BitShift",
        "Ceil",
        "Clip",
        "Conv",
        "CumSum",
        "Exp",
        "Expand",
        "Floor",
        "Gemm",
        "IsNaN",
        "Log",
        "LogSoftmax",
        "LpNormalization",
        "MatMul",
        "Max",
        "MaxPool",
        "Mean",
        "Min",
        "NonZero",
        "Pad",
        "QLinearConv",
        "QLinearMatMul",
        "Range",
        "Reciprocal",
        "ReduceL1",
        "ReduceL2",
        "ReduceLogSum",
        "ReduceLogSumExp",
        "ReduceMax",
        "ReduceMean",
        "ReduceMin",
        "ReduceProd",
        "ReduceSum",
        "ReduceSumSquare",
        "Relu",
        "Resize",
        "ReverseSequence",
        "RoiAlign",
        "Round",
        "Scatter",
        "ScatterElements",
        "ScatterND",
        "Shrink",
        "Sigmoid",
        "Sign",
        "Sin",
        "Softmax",
        "Split",
        "SplitToSequence",
        "Sqrt",
        "Sum",
        "Tanh",
        "TopK",
        "Transpose",
        "Unique",
    ]

    # ops that are used to manipulate shapes or indices so require int32_t and int64_t to be available
    default_processor_onnx_ops_requiring_ints_for_input_0 = [
        "Add",
        "Concat",
        "Div",
        "Equal",
        "Greater",
        "Less",
        "Mul",
        "Neg",  # used in tflite TransposeConv conversion
        "Sub",
    ]

    # NOTE: QLinearConv has ONNX and internal implementations
    internal_ops = ["QLinearAdd", "QLinearMul", "QLinearConv"]

    # TODO - review and add ML ops as needed
    # ML Op notes.
    #  CastMap: Switch on value type of input map type, and output type
    #  DictVectorizer: Templatized on key+value of input so need to handle like OneHot with custom processor
    #  LabelEncoder: Implementation switches on input and output types (only supports string and int64 in T1 and T2)
    #  LinearClassifier: Internal switch on input type and also switch on output type
    #  SVMClassifier: ditto
    #  TreeEnsembleClassifier: Templatized on input type and also switch on output type
    #  ZipMap: Switch on output type (derived from attributes)
    default_processor_onnxml_ops = []

    # plain loops, not list comprehensions: add() is called purely for its side effect and a
    # comprehension would build a throwaway list of Nones (ruff PERF401)
    for op in default_processor_onnx_ops:
        add(DefaultTypeUsageProcessor("ai.onnx", op))

    for op in default_processor_onnx_ops_requiring_ints_for_input_0:
        add(DefaultTypeUsageProcessor("ai.onnx", op, required_input_types={0: {"int32_t", "int64_t"}}))

    for op in default_processor_onnxml_ops:
        add(DefaultTypeUsageProcessor("ai.onnx.ml", op))

    for op in internal_ops:
        add(DefaultTypeUsageProcessor("com.microsoft", op))

    #
    # Operators that require custom handling
    #

    # Cast switches on types of input 0 and output 0
    add(DefaultTypeUsageProcessor("ai.onnx", "Cast", inputs=[0], outputs=[0]))

    # Operators that switch on the type of input 0 and 1
    add(DefaultTypeUsageProcessor("ai.onnx", "Gather", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "GatherElements", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "Pow", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "Slice", inputs=[0, 1]))

    # Operators that switch on output type
    add(DefaultTypeUsageProcessor("ai.onnx", "ConstantOfShape", inputs=[], outputs=[0]))

    # Random generator ops produce new data so we track the output type
    onnx_random_ops = ["RandomNormal", "RandomNormalLike", "RandomUniform", "RandomUniformLike", "Multinomial"]
    for op in onnx_random_ops:
        add(DefaultTypeUsageProcessor("ai.onnx", op, inputs=[], outputs=[0]))

    # Where always has a boolean first input so track the second input type for typed registration
    add(Input1TypedRegistrationProcessor("ai.onnx", "Where"))

    # we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type
    # as that's what is used in the typed registration
    add(Output0TypedRegistrationProcessor("ai.onnx", "QuantizeLinear"))
    add(Output0TypedRegistrationProcessor("ai.onnx", "DynamicQuantizeLinear"))

    # make sure all the dequantize types are enabled. we use int32_t for parts of GEMM and Conv so just
    # enabling int8 and uint8 is not enough.
    # TODO: Only apply required types to the global type list and ignore if it's model based per-op type reduction
    add(
        DefaultTypeUsageProcessor(
            "ai.onnx", "DequantizeLinear", inputs=[0], required_input_types={0: {"int8_t", "uint8_t", "int32_t"}}
        )
    )

    # OneHot concatenates type strings into a triple in the typed registration
    # e.g. float_int64_t_int64_t
    add(OneHotProcessor())

    return operator_processors
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
class OpTypeImplFilterInterface(ABC):
    """
    Interface for classes that filter operator kernel implementations based on type.
    """

    @abstractmethod
    def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
        """
        Decide whether a typed kernel registration must be kept.
        :param domain: Operator domain.
        :param optype: Operator type.
        :param type_registration_str: Type string from kernel registration
        :return: True is required. False if not.
        """

    @abstractmethod
    def get_cpp_entries(self):
        """
        Produce the C++ source lines that enable the required operator types.
        :return: List of strings. One line of C++ code per entry.
        """
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
class OperatorTypeUsageManager:
    """
    Class to manage the operator type usage processors.
    TODO: Currently the type tracking is not specific to a version of the operator.
    It's unclear how/where version specific logic could/should be added, and it would add significant complexity
    to track types on a per-version basis. Not clear there's enough benefit from doing so either.
    """

    def __init__(self):
        self._all_operator_processors = _create_operator_type_usage_processors()  # all possible processors
        self._operator_processors = {}  # processors we have actually used so we can limit output to be meaningful

    def _get_op_processor(self, key):
        "Add the processor to _operator_processors as it is about to be used."
        processor = None
        if key in self._all_operator_processors:
            if key not in self._operator_processors:
                self._operator_processors[key] = self._all_operator_processors[key]

            processor = self._operator_processors[key]

        return processor

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """
        Process a Node and record info on the types used.
        :param node: Node from ORT format model
        :param value_name_to_typeinfo: Map of value names to TypeInfo instances
        """
        optype = node.OpType().decode()
        domain = node.Domain().decode() or "ai.onnx"  # empty domain defaults to ai.onnx

        key = _create_op_key(domain, optype)
        op_processor = self._get_op_processor(key)
        if op_processor:
            op_processor.process_node(node, value_name_to_typeinfo)

    def get_config_entry(self, domain: str, optype: str):
        """
        Get the config entry specifying the types for this operator.
        :param domain: Operator domain.
        :param optype: Operator type.
        :return: JSON string with type info if available, else None
        """
        key = _create_op_key(domain, optype)
        config_str = None
        if key in self._operator_processors:
            config_str = self._operator_processors[key].to_config_entry()

        return config_str

    def restore_from_config_entry(self, domain: str, optype: str, config_entry: str):
        """
        Restore the per-operator type information from a configuration file entry.
        :param domain: Operator domain.
        :param optype: Operator type.
        :param config_entry: JSON string with type info as created by get_config_entry
        """
        key = _create_op_key(domain, optype)
        op_processor = self._get_op_processor(key)
        if op_processor:
            op_processor.from_config_entry(config_entry)

    def debug_dump(self):
        """Print the C++ entries and config entries for the processors used so far."""
        print("C++ code that will be emitted:")
        # BUG FIX: this class has no get_cpp_entries() method - the entries are produced by the
        # impl filter. The previous `self.get_cpp_entries()` call raised AttributeError.
        for cpp_line in self.make_op_type_impl_filter().get_cpp_entries():
            print(cpp_line)

        print("Config file type information that will be returned by get_config_entry:")
        for key in sorted(self._operator_processors.keys()):
            entry = self._operator_processors[key].to_config_entry()
            if entry:
                print(f"{key} -> {entry}")

                # roundtrip test to validate that we can initialize the processor from the entry and get the
                # same values back
                self._operator_processors[key].from_config_entry(entry)
                assert entry == self._operator_processors[key].to_config_entry()

    class _OpTypeImplFilter(OpTypeImplFilterInterface):
        # filter backed by the manager's per-operator processor state
        def __init__(self, manager):
            self._manager = manager

        def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
            needed = True  # we keep the registration unless the per-operator processor says not to
            key = _create_op_key(domain, optype)
            if key in self._manager._operator_processors:
                needed = self._manager._operator_processors[key].is_typed_registration_needed(
                    type_in_registration=type_registration_str, globally_allowed_types=None
                )

            return needed

        def get_cpp_entries(self):
            entries = []
            for key in sorted(self._manager._operator_processors.keys()):
                entries.extend(self._manager._operator_processors[key].get_cpp_entry())

            return entries

    def make_op_type_impl_filter(self):
        """
        Creates an OpTypeImplFilterInterface instance from this manager.
        Filtering uses the manager's operator type usage processor state.
        """
        return OperatorTypeUsageManager._OpTypeImplFilter(self)
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
class GloballyAllowedTypesOpTypeImplFilter(OpTypeImplFilterInterface):
    """
    Operator implementation filter which uses globally allowed types.
    """

    _valid_allowed_types = set(FbsTypeInfo.tensordatatype_to_string.values())  # noqa: RUF012

    def __init__(self, globally_allowed_types: set[str]):
        self._operator_processors = _create_operator_type_usage_processors()

        # reject anything that is not a known tensor data type string
        if not globally_allowed_types <= self._valid_allowed_types:
            raise ValueError(
                f"Globally allowed types must all be valid. Invalid types: {sorted(globally_allowed_types - self._valid_allowed_types)}"
            )

        self._globally_allowed_types = globally_allowed_types

    def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
        """Keep the registration only when its type(s) are in the globally allowed set."""
        processor = self._operator_processors.get(_create_op_key(domain, optype))
        if processor is not None:
            return processor.is_typed_registration_needed(
                type_in_registration=type_registration_str, globally_allowed_types=self._globally_allowed_types
            )

        # no per-operator processor: check the single registration type directly
        return _reg_type_to_cpp_type(type_registration_str) in self._globally_allowed_types

    def get_cpp_entries(self):
        """Single global macro line listing every allowed type."""
        joined_types = ", ".join(sorted(self._globally_allowed_types))
        return ["ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES({});".format(joined_types)]

    def global_type_list(self):
        """Return the set of globally allowed type strings."""
        return self._globally_allowed_types
|
|
File without changes
|