onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
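
The native runtime under onnxruntime/capi (onnxruntime.dll, DirectML.dll) is exposed through the usual InferenceSession API. As a point of orientation, here is a minimal sketch of how this wheel is typically exercised; the model path and input name are placeholders, and it assumes a Windows machine with a DirectX 12 capable GPU.

import numpy as np
import onnxruntime as ort

# The DirectML build exposes DmlExecutionProvider in addition to the CPU provider.
print(ort.get_available_providers())

# "model.onnx" and the input name "x" are placeholders for illustration.
session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)
outputs = session.run(None, {"x": np.zeros((1, 3, 224, 224), dtype=np.float32)})
print([o.shape for o in outputs])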
Binary file

Binary file

onnxruntime/capi/onnxruntime_validation.py
@@ -0,0 +1,154 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""
+Check OS requirements for ONNX Runtime Python Bindings.
+"""
+
+import linecache
+import platform
+import warnings
+
+
+def check_distro_info():
+    __my_distro__ = ""
+    __my_distro_ver__ = ""
+    __my_system__ = platform.system().lower()
+
+    __OS_RELEASE_FILE__ = "/etc/os-release"  # noqa: N806
+    __LSB_RELEASE_FILE__ = "/etc/lsb-release"  # noqa: N806
+
+    if __my_system__ == "windows":
+        __my_distro__ = __my_system__
+        __my_distro_ver__ = platform.release().lower()
+
+        if __my_distro_ver__ not in ["10", "11", "2016server", "2019server", "2022server", "2025server"]:
+            warnings.warn(
+                f"Unsupported Windows version ({__my_distro_ver__}). ONNX Runtime supports Windows 10 and above, or Windows Server 2016 and above."
+            )
+    elif __my_system__ == "linux":
+        """Although the 'platform' python module for getting Distro information works well on standard OS images
+        running on real hardware, it is not accurate when running on Azure VMs, Git Bash, Cygwin, etc.
+        The returned values for release and version are unpredictable for virtualized or emulated environments.
+        /etc/os-release and /etc/lsb_release files, on the other hand, are guaranteed to exist and have standard values
+        in all OSes supported by onnxruntime. The former is the current standard file to check OS info and the latter
+        is its predecessor.
+        """
+        # Newer systems have /etc/os-release with relevant distro info
+        __my_distro__ = linecache.getline(__OS_RELEASE_FILE__, 3)[3:-1]
+        __my_distro_ver__ = linecache.getline(__OS_RELEASE_FILE__, 6)[12:-2]
+
+        # Older systems may have /etc/os-release instead
+        if not __my_distro__:
+            __my_distro__ = linecache.getline(__LSB_RELEASE_FILE__, 1)[11:-1]
+            __my_distro_ver__ = linecache.getline(__LSB_RELEASE_FILE__, 2)[16:-1]
+
+        # Instead of trying to parse distro specific files,
+        # warn the user ONNX Runtime may not work out of the box
+        __my_distro__ = __my_distro__.lower()
+        __my_distro_ver__ = __my_distro_ver__.lower()
+    elif __my_system__ == "darwin":
+        __my_distro__ = __my_system__
+        __my_distro_ver__ = platform.release().lower()
+
+        if int(__my_distro_ver__.split(".")[0]) < 11:
+            warnings.warn(
+                f"Unsupported macOS version ({__my_distro_ver__}). ONNX Runtime supports macOS 11.0 or later."
+            )
+    elif __my_system__ == "aix":
+        import subprocess  # noqa: PLC0415
+
+        returned_output = subprocess.check_output("oslevel")
+        __my_distro_ver__str = returned_output.decode("utf-8")
+        __my_distro_ver = __my_distro_ver__str[:3]
+    else:
+        warnings.warn(
+            f"Unsupported platform ({__my_system__}). ONNX Runtime supports Linux, macOS, AIX and Windows platforms, only."
+        )
+
+
+def get_package_name_and_version_info():
+    package_name = ""
+    version = ""
+    cuda_version = ""
+
+    try:
+        from .build_and_package_info import __version__ as version  # noqa: PLC0415
+        from .build_and_package_info import package_name  # noqa: PLC0415
+
+        try:  # noqa: SIM105
+            from .build_and_package_info import cuda_version  # noqa: PLC0415
+        except ImportError:
+            # cuda_version is optional. For example, cpu only package does not have the attribute.
+            pass
+    except Exception as e:
+        warnings.warn("WARNING: failed to collect package name and version info")
+        print(e)
+
+    return package_name, version, cuda_version
+
+
+def check_training_module():
+    import_ortmodule_exception = None
+
+    has_ortmodule = False
+    try:
+        from onnxruntime.training.ortmodule import ORTModule  # noqa: F401, PLC0415
+
+        has_ortmodule = True
+    except ImportError:
+        # ORTModule not present
+        has_ortmodule = False
+    except Exception as e:
+        # this may happen if Cuda is not installed, we want to raise it after
+        # for any exception other than not having ortmodule, we want to continue
+        # device version validation and raise the exception after.
+        try:
+            from onnxruntime.training.ortmodule._fallback import ORTModuleInitException  # noqa: PLC0415
+
+            if isinstance(e, ORTModuleInitException):
+                # ORTModule is present but not ready to run yet
+                has_ortmodule = True
+        except Exception:
+            # ORTModule not present
+            has_ortmodule = False
+
+        if not has_ortmodule:
+            import_ortmodule_exception = e
+
+    # collect onnxruntime package name, version, and cuda version
+    package_name, version, cuda_version = get_package_name_and_version_info()
+
+    if has_ortmodule and cuda_version:
+        try:
+            # collect cuda library build info. the library info may not be available
+            # when the build environment has none or multiple libraries installed
+            try:
+                from .build_and_package_info import cudart_version  # noqa: PLC0415
+            except ImportError:
+                warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
+                cudart_version = None
+
+            def print_build_package_info():
+                warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
+                warnings.warn(f"onnxruntime training package info: __version__: {version}")
+                warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
+                warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")
+
+            # collection cuda library info from current environment.
+            from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions  # noqa: PLC0415
+
+            local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
+            if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions:
+                print_build_package_info()
+                warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
+                warnings.warn(f"WARNING: found cudart versions: {local_cudart_versions}")
+        except Exception as e:
+            warnings.warn("WARNING: failed to collect onnxruntime version and build info")
+            print(e)
+
+    if import_ortmodule_exception:
+        raise import_ortmodule_exception
+
+    return has_ortmodule, package_name, version, cuda_version
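
The helpers above are plain module-level functions, so they can also be invoked directly when troubleshooting an installation. A minimal sketch using only the names defined in this file:

from onnxruntime.capi import onnxruntime_validation as ov

# Emits a warning if the OS is outside the supported range, otherwise stays silent.
ov.check_distro_info()

# Returns (package_name, version, cuda_version); cuda_version stays empty for
# builds that do not carry the attribute, such as this DirectML package.
print(ov.get_package_name_and_version_info())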

onnxruntime/datasets/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+"""
+Short examples used in the documentation.
+"""
+
+import os
+
+
+def get_example(name):
+    """
+    Retrieves the absolute file name of an example.
+    """
+    this = os.path.abspath(os.path.dirname(__file__))
+    full = os.path.join(this, name)
+    if not os.path.exists(full):
+        raise FileNotFoundError(f"Unable to find example '{name}'")
+    return full
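
get_example only resolves a path inside the installed package; the bundled models listed earlier (logreg_iris.onnx, mul_1.onnx, sigmoid.onnx) are then loaded like any other ONNX file. A short sketch, assuming the wheel is installed; the (3, 4, 5) input shape follows the ONNX Runtime documentation example for sigmoid.onnx:

import numpy as np
import onnxruntime as ort
from onnxruntime.datasets import get_example

# Resolve the absolute path of a model shipped inside the wheel.
model_path = get_example("sigmoid.onnx")

session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
x = np.random.rand(3, 4, 5).astype(np.float32)
(y,) = session.run(None, {input_name: x})
print(y.shape)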

Binary file

Binary file

onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py
@@ -0,0 +1,78 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: CalTableFlatBuffers
+
+import flatbuffers
+from flatbuffers.compat import import_numpy
+
+np = import_numpy()
+
+
+class KeyValue:
+    __slots__ = ["_tab"]
+
+    @classmethod
+    def GetRootAs(cls, buf, offset=0):  # noqa: N802
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = KeyValue()
+        x.Init(buf, n + offset)
+        return x
+
+    @classmethod
+    def GetRootAsKeyValue(cls, buf, offset=0):  # noqa: N802
+        """This method is deprecated. Please switch to GetRootAs."""
+        return cls.GetRootAs(buf, offset)
+
+    # KeyValue
+    def Init(self, buf, pos):  # noqa: N802
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # KeyValue
+    def Key(self):  # noqa: N802
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.String(o + self._tab.Pos)
+        return None
+
+    # KeyValue
+    def Value(self):  # noqa: N802
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.String(o + self._tab.Pos)
+        return None
+
+
+def Start(builder):  # noqa: N802
+    builder.StartObject(2)
+
+
+def KeyValueStart(builder):  # noqa: N802
+    """This method is deprecated. Please switch to Start."""
+    return Start(builder)
+
+
+def AddKey(builder, key):  # noqa: N802
+    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(key), 0)
+
+
+def KeyValueAddKey(builder, key):  # noqa: N802
+    """This method is deprecated. Please switch to AddKey."""
+    return AddKey(builder, key)
+
+
+def AddValue(builder, value):  # noqa: N802
+    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(value), 0)
+
+
+def KeyValueAddValue(builder, value):  # noqa: N802
+    """This method is deprecated. Please switch to AddValue."""
+    return AddValue(builder, value)
+
+
+def End(builder):  # noqa: N802
+    return builder.EndObject()
+
+
+def KeyValueEnd(builder):  # noqa: N802
+    """This method is deprecated. Please switch to End."""
+    return End(builder)

onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py
@@ -0,0 +1,90 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: CalTableFlatBuffers
+
+import flatbuffers
+from flatbuffers.compat import import_numpy
+
+np = import_numpy()
+
+
+class TrtTable:
+    __slots__ = ["_tab"]
+
+    @classmethod
+    def GetRootAs(cls, buf, offset=0):  # noqa: N802
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = TrtTable()
+        x.Init(buf, n + offset)
+        return x
+
+    @classmethod
+    def GetRootAsTrtTable(cls, buf, offset=0):  # noqa: N802
+        """This method is deprecated. Please switch to GetRootAs."""
+        return cls.GetRootAs(buf, offset)
+
+    # TrtTable
+    def Init(self, buf, pos):  # noqa: N802
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # TrtTable
+    def Dict(self, j):  # noqa: N802
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            x = self._tab.Vector(o)
+            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
+            x = self._tab.Indirect(x)
+            from onnxruntime.quantization.CalTableFlatBuffers.KeyValue import KeyValue  # noqa: PLC0415
+
+            obj = KeyValue()
+            obj.Init(self._tab.Bytes, x)
+            return obj
+        return None
+
+    # TrtTable
+    def DictLength(self):  # noqa: N802
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # TrtTable
+    def DictIsNone(self):  # noqa: N802
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        return o == 0
+
+
+def Start(builder):  # noqa: N802
+    builder.StartObject(1)
+
+
+def TrtTableStart(builder):  # noqa: N802
+    """This method is deprecated. Please switch to Start."""
+    return Start(builder)
+
+
+def AddDict(builder, dict):  # noqa: N802
+    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(dict), 0)
+
+
+def TrtTableAddDict(builder, dict):  # noqa: N802
+    """This method is deprecated. Please switch to AddDict."""
+    return AddDict(builder, dict)
+
+
+def StartDictVector(builder, numElems):  # noqa: N802
+    return builder.StartVector(4, numElems, 4)
+
+
+def TrtTableStartDictVector(builder, numElems):  # noqa: N802
+    """This method is deprecated. Please switch to Start."""
+    return StartDictVector(builder, numElems)
+
+
+def End(builder):  # noqa: N802
+    return builder.EndObject()
+
+
+def TrtTableEnd(builder):  # noqa: N802
+    """This method is deprecated. Please switch to End."""
+    return End(builder)
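
These generated bindings describe a table whose single Dict field is a vector of KeyValue string pairs, the schema used for calibration tables. A minimal sketch of building and reading back a one-entry table by hand, assuming a recent flatbuffers 2.x Python package (where Builder.EndVector takes no argument); the key and value strings are placeholders:

import flatbuffers

from onnxruntime.quantization.CalTableFlatBuffers import KeyValue, TrtTable

builder = flatbuffers.Builder(1024)

# Strings must be created before the table that references them.
key = builder.CreateString("Conv_0_output")  # placeholder tensor name
value = builder.CreateString("0.0 1.0")      # placeholder calibration range

KeyValue.Start(builder)
KeyValue.AddKey(builder, key)
KeyValue.AddValue(builder, value)
kv = KeyValue.End(builder)

# The Dict field is a vector of KeyValue tables.
TrtTable.StartDictVector(builder, 1)
builder.PrependUOffsetTRelative(kv)
dict_vec = builder.EndVector()

TrtTable.Start(builder)
TrtTable.AddDict(builder, dict_vec)
builder.Finish(TrtTable.End(builder))

# Read the entry back from the serialized buffer.
table = TrtTable.TrtTable.GetRootAs(builder.Output(), 0)
entry = table.Dict(0)
print(table.DictLength(), entry.Key(), entry.Value())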

File without changes

onnxruntime/quantization/__init__.py
@@ -0,0 +1,19 @@
+from .calibrate import (  # noqa: F401
+    CalibraterBase,
+    CalibrationDataReader,
+    CalibrationMethod,
+    MinMaxCalibrater,
+    create_calibrator,
+)
+from .qdq_quantizer import QDQQuantizer  # noqa: F401
+from .quant_utils import QuantFormat, QuantType, write_calibration_table  # noqa: F401
+from .quantize import (
+    DynamicQuantConfig,  # noqa: F401
+    QuantizationMode,  # noqa: F401
+    StaticQuantConfig,  # noqa: F401
+    get_qdq_config,  # noqa: F401
+    quantize,  # noqa: F401
+    quantize_dynamic,  # noqa: F401
+    quantize_static,  # noqa: F401
+)
+from .shape_inference import quant_pre_process  # noqa: F401