onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/whisper/convert_to_onnx.py
@@ -0,0 +1,609 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import argparse
+import logging
+import os
+import warnings
+
+import onnx
+import torch
+from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
+from whisper_chain import chain_model
+from whisper_encoder import WhisperEncoder
+from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper
+
+from onnxruntime.quantization.matmul_nbits_quantizer import (
+    KQuantWeightOnlyQuantConfig,
+    MatMulNBitsQuantizer,
+    QuantFormat,
+)
+
+logger = logging.getLogger("")
+
+PROVIDERS = {
+    "cpu": "CPUExecutionProvider",
+    "cuda": "CUDAExecutionProvider",
+}
+
+
+def parse_arguments(argv=None):
+    parser = argparse.ArgumentParser()
+
+    conversion_args = parser.add_argument_group("Conversion Process Args")
+    optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)")
+    optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)")
+    quant_args = parser.add_argument_group("INT8 Quantization Args")
+
+    #################################
+    # Conversion options for Whisper
+    #################################
+
+    conversion_args.add_argument(
+        "-m",
+        "--model_name_or_path",
+        required=False,
+        default=PRETRAINED_WHISPER_MODELS[0],
+        type=str,
+        help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS),
+    )
+
+    conversion_args.add_argument(
+        "--model_impl",
+        required=False,
+        default="hf",
+        choices=["hf", "openai"],
+        type=str,
+        help="Select implementation for export of encoder and decoder subgraphs",
+    )
+
+    conversion_args.add_argument(
+        "--cache_dir",
+        required=False,
+        type=str,
+        default=os.path.join(".", "cache_models"),
+        help="Directory to cache pre-trained models",
+    )
+
+    conversion_args.add_argument(
+        "--output",
+        required=False,
+        type=str,
+        default=os.path.join(".", "onnx_models"),
+        help="Output directory",
+    )
+
+    conversion_args.add_argument(
+        "-o",
+        "--optimize_onnx",
+        required=False,
+        action="store_true",
+        help="Use optimizer.py to optimize onnx model",
+    )
+    conversion_args.set_defaults(optimize_onnx=False)
+
+    conversion_args.add_argument(
+        "--use_gpu",
+        required=False,
+        action="store_true",
+        help="Use GPU for model inference",
+    )
+    conversion_args.set_defaults(use_gpu=False)
+
+    conversion_args.add_argument(
+        "-p",
+        "--precision",
+        required=False,
+        type=Precision,
+        default=Precision.FLOAT32,
+        choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4],
+        help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8/int4 for quantization",
+    )
+
+    conversion_args.add_argument(
+        "--use_int64_inputs",
+        required=False,
+        action="store_true",
+        help="Use int64 instead of int32 for input_ids and attention_mask.",
+    )
+    conversion_args.set_defaults(use_int64_inputs=False)
+
+    conversion_args.add_argument(
+        "-r",
+        "--provider",
+        required=False,
+        type=str,
+        default="cpu",
+        choices=list(PROVIDERS.keys()),
+        help="Provider to benchmark. Default is CPUExecutionProvider.",
+    )
+
+    conversion_args.add_argument(
+        "--verbose",
+        required=False,
+        action="store_true",
+        help="Enable verbose logging",
+    )
+    conversion_args.set_defaults(verbose=False)
+
+    conversion_args.add_argument(
+        "-e",
+        "--use_external_data_format",
+        required=False,
+        action="store_true",
+        help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.",
+    )
+    conversion_args.set_defaults(use_external_data_format=False)
+
+    conversion_args.add_argument(
+        "-w",
+        "--overwrite",
+        required=False,
+        action="store_true",
+        help="Overwrite existing ONNX model",
+    )
+    conversion_args.set_defaults(overwrite=False)
+
+    conversion_args.add_argument(
+        "--separate_encoder_and_decoder_init",
+        required=False,
+        action="store_true",
+        help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.",
+    )
+    conversion_args.set_defaults(separate_encoder_and_decoder_init=False)
+
+    conversion_args.add_argument(
+        "--no_beam_search_op",
+        required=False,
+        action="store_true",
+        help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.",
+    )
+    conversion_args.set_defaults(no_beam_search_op=False)
+
+    conversion_args.add_argument(
+        "--use_decoder_masked_mha",
+        required=False,
+        action="store_true",
+        help="Use DecoderMaskedMultiHeadAttention kernel for improved performance. This is currently an experimental feature.",
+    )
+    conversion_args.set_defaults(use_decoder_masked_mha=False)
+
+    #############################################################
+    # Optional inputs for Whisper
+    # (listed below in the order that WhisperBeamSearch expects)
+    #############################################################
+
+    optional_inputs.add_argument(
+        "-v",
+        "--use_vocab_mask",
+        required=False,
+        action="store_true",
+        help="Use vocab_mask as an extra graph input to enable specific logits processing",
+    )
+    optional_inputs.set_defaults(use_vocab_mask=False)
+
+    optional_inputs.add_argument(
+        "-u",
+        "--use_prefix_vocab_mask",
+        required=False,
+        action="store_true",
+        help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing",
+    )
+    optional_inputs.set_defaults(use_prefix_vocab_mask=False)
+
+    optional_inputs.add_argument(
+        "-f",
+        "--use_forced_decoder_ids",
+        required=False,
+        action="store_true",
+        help="Use decoder_input_ids as an extra graph input to the beam search op",
+    )
+    optional_inputs.set_defaults(use_forced_decoder_ids=False)
+
+    optional_inputs.add_argument(
+        "-l",
+        "--use_logits_processor",
+        required=False,
+        action="store_true",
+        help="Use logits_processor as an extra graph input to enable specific logits processing",
+    )
+    optional_inputs.set_defaults(use_logits_processor=False)
+
+    optional_inputs.add_argument(
+        "--collect_cross_qk",
+        required=False,
+        action="store_true",
+        help="Beam search model collects stacked cross QK.",
+    )
+    optional_inputs.set_defaults(collect_cross_qk=False)
+
+    optional_inputs.add_argument(
+        "--extra_decoding_ids",
+        required=False,
+        action="store_true",
+        help="Need extra starting decoding ids for features like cross QK. Default is false.",
+    )
+    optional_inputs.set_defaults(extra_decoding_ids=False)
+
+    optional_inputs.add_argument(
+        "-t",
+        "--use_temperature",
+        required=False,
+        action="store_true",
+        help="Use temperature as an extra graph input for the WhisperBeamSearch op",
+    )
+    optional_inputs.set_defaults(use_temperature=False)
+
+    optional_inputs.add_argument(
+        "--no_repeat_ngram_size",
+        type=int,
+        default=0,
+        help="No repeat n-gram size. Default is 0.",
+    )
+
+    #############################################################
+    # Optional outputs for Whisper
+    # (listed below in the order that WhisperBeamSearch expects)
+    #############################################################
+
+    optional_outputs.add_argument(
+        "--output_sequence_scores",
+        required=False,
+        action="store_true",
+        help="Beam search model outputs scores for each generated sequence.",
+    )
+    optional_outputs.set_defaults(output_sequence_scores=False)
+
+    optional_outputs.add_argument(
+        "--output_scores",
+        required=False,
+        action="store_true",
+        help="Beam search model outputs scores over the vocab per generated token.",
+    )
+    optional_outputs.set_defaults(output_scores=False)
+
+    optional_outputs.add_argument(
+        "--output_cross_qk",
+        required=False,
+        action="store_true",
+        help="Beam search model outputs collected cross QK. Also implies --collect_cross_qk.",
+    )
+    optional_outputs.set_defaults(output_cross_qk=False)
+
+    optional_outputs.add_argument(
+        "--cross_qk_onnx_model",
+        required=False,
+        type=str,
+        default=None,
+        help="The model which consumes cross_qk outputs.",
+    )
+
+    optional_outputs.add_argument(
+        "--output_no_speech_probs",
+        required=False,
+        action="store_true",
+        help="Beam search model outputs no-speech probs, which are computed from the encoder/context-decoder graph.",
+    )
+    optional_outputs.set_defaults(output_no_speech_probs=False)
+
+    ###################################
+    # Quantization options for Whisper
+    ###################################
+
+    quant_args.add_argument(
+        "--accuracy_level",
+        default=0,
+        required=False,
+        type=int,
+        help="Accuracy level of the 4-bit quantized MatMul computation.",
+    )
+
+    quant_args.add_argument(
+        "--quantize_symmetric",
+        required=False,
+        action="store_true",
+        help="Quantize weights symmetrically",
+    )
+    quant_args.set_defaults(quantize_symmetric=False)
+
+    args = parser.parse_args(argv)
+
+    # Collect cross QKs if either flag is enabled
+    args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk
+
+    # FP32 CPU can be supported here once the DMMHA CPU kernel bugs are fixed
+    args.use_decoder_masked_mha = args.use_decoder_masked_mha and args.provider == "cuda"
+
+    return args
+
+
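For orientation, parse_arguments accepts an explicit argv list, so the converter can be driven programmatically as well as from a shell. A minimal sketch that runs in this module's context; the model name and flag combination are illustrative values, not defaults taken from this diff:

    # Illustrative invocation; "openai/whisper-tiny" is an example model name.
    args = parse_arguments(["-m", "openai/whisper-tiny", "--optimize_onnx", "--use_external_data_format"])
    print(args.precision)               # Precision.FLOAT32, the default
    print(args.use_decoder_masked_mha)  # False: the flag is only honored when --provider is cuda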
+# quant_method is reserved for mixed precision in future
+def make_quant_algo_config(precision, quant_method: str, matmul_nodes=None):
+    customized_weight_config = {}
+    quant_algo_config = None
+
+    # need to use k_quant for int8
+    if precision == Precision.INT8:
+        for node_name in matmul_nodes:
+            customized_weight_config[node_name] = {"bits": 8}
+        quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+    else:
+        quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+
+    return quant_algo_config
+
+
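Both branches above return a KQuantWeightOnlyQuantConfig: the INT8 branch pre-populates the per-node override table with {"bits": 8} for every listed MatMul, while any other precision (INT4 in practice) leaves the table empty and falls back to the quantizer's 4-bit defaults. A small sketch of the two call shapes, with made-up node names for illustration:

    # Hypothetical node names; real ones come from the exported ONNX graph.
    int8_cfg = make_quant_algo_config(Precision.INT8, "k_quant", ["/decoder/MatMul_1", "/decoder/MatMul_2"])
    int4_cfg = make_quant_algo_config(Precision.INT4, "k_quant")  # matmul_nodes is unused on this path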
+def export_onnx_models(
+    model_name_or_path,
+    model_impl,
+    cache_dir,
+    output_dir,
+    use_gpu,
+    use_external_data_format,
+    optimize_onnx,
+    precision,
+    verbose,
+    use_forced_decoder_ids: bool = False,
+    merge_encoder_and_decoder_init: bool = True,
+    no_beam_search_op: bool = False,
+    use_decoder_masked_mha: bool = False,
+    output_qk: bool = False,
+    overwrite: bool = False,
+    use_int32_inputs: bool = True,
+    accuracy_level: int = 0,
+    quantize_symmetric: bool = False,
+    provider: str = "cpu",
+):
+    device = torch.device("cuda" if use_gpu else "cpu")
+    if not use_gpu:
+        accuracy_level = 4  # change to 4 for CPU EP
+    use_fp16_inputs = precision == Precision.FLOAT16 or (precision in (Precision.INT8, Precision.INT4) and use_gpu)
+
+    models = WhisperHelper.load_model(
+        model_name_or_path,
+        model_impl,
+        cache_dir,
+        device,
+        torch.float16 if use_fp16_inputs else torch.float32,
+        merge_encoder_and_decoder_init,
+        no_beam_search_op,
+        output_qk,
+    )
+    config = models["decoder"].config
+
+    if (not use_external_data_format) and (config.num_hidden_layers > 24):
+        logger.warning("You MUST pass `--use_external_data_format` because model size > 2GB")
+        raise Exception("Please pass `--use_external_data_format` for this model.")
+
+    output_paths = []
+    for name, model in models.items():
+        print(f"========> Handling {name} model......")
+        filename_suffix = "_" + name
+
+        onnx_path = WhisperHelper.get_onnx_path(
+            output_dir,
+            model_name_or_path,
+            suffix=filename_suffix,
+            new_folder=False,
+        )
+
+        # Export to ONNX
+        if overwrite or not os.path.exists(onnx_path):
+            logger.info(f"Exporting ONNX model to {onnx_path}")
+            WhisperHelper.export_onnx(
+                model,
+                onnx_path,
+                PROVIDERS[provider],
+                verbose,
+                use_external_data_format,
+                use_fp16_inputs=use_fp16_inputs,
+                use_int32_inputs=use_int32_inputs,
+                use_encoder_hidden_states=(name == "decoder_init"),
+                use_kv_cache_inputs=(name == "decoder"),
+            )
+        else:
+            logger.info(f"Skip exporting: existing ONNX model {onnx_path}")
+
+        # Optimize ONNX model
+        if optimize_onnx or precision != Precision.FLOAT32:
+            output_path = WhisperHelper.get_onnx_path(
+                output_dir,
+                model_name_or_path,
+                suffix=filename_suffix + "_" + str(precision),
+                new_folder=False,
+            )
+
+            if overwrite or not os.path.exists(output_path):
+                if optimize_onnx:
+                    logger.info(f"Optimizing model to {output_path}")
+                    WhisperHelper.optimize_onnx(
+                        onnx_path,
+                        output_path,
+                        precision == Precision.FLOAT16,
+                        model.config.encoder_attention_heads,
+                        model.config.d_model,
+                        model.config.decoder_layers,
+                        use_external_data_format,
+                        use_gpu=use_gpu,
+                        provider=provider,
+                        is_decoder=(name == "decoder"),
+                        no_beam_search_op=no_beam_search_op,
+                        use_decoder_masked_mha=use_decoder_masked_mha,
+                        output_qk=output_qk,
+                    )
+                    # Remove old ONNX model and old data file
+                    if os.path.exists(onnx_path):
+                        os.remove(onnx_path)
+                    if os.path.exists(onnx_path + ".data"):
+                        os.remove(onnx_path + ".data")
+                    onnx_path = output_path
+
+                    if isinstance(model, WhisperEncoder):
+                        model.verify_onnx(
+                            onnx_path,
+                            PROVIDERS[provider],
+                            use_fp16_inputs=use_fp16_inputs,
+                        )
+                    else:
+                        model.verify_onnx(
+                            onnx_path,
+                            PROVIDERS[provider],
+                            use_fp16_inputs=use_fp16_inputs,
+                            use_int32_inputs=use_int32_inputs,
+                        )
+
+                if precision in (Precision.INT8, Precision.INT4):
+                    onnx_model = onnx.load(onnx_path, load_external_data=True)
+                    matmul_nodes = [node.name for node in onnx_model.graph.node if node.op_type == "MatMul"]
+                    quant_algo_config = make_quant_algo_config(precision, "k_quant", matmul_nodes)
+
+                    quant = MatMulNBitsQuantizer(
+                        model=onnx_model,
+                        block_size=32,
+                        is_symmetric=quantize_symmetric,
+                        accuracy_level=accuracy_level,
+                        quant_format=QuantFormat.QOperator,
+                        op_types_to_quantize=("MatMul",),
+                        algo_config=quant_algo_config,
+                    )
+                    quant.process()
+                    if os.path.exists(output_path):
+                        os.remove(output_path)
+                    if os.path.exists(output_path + ".data"):
+                        os.remove(output_path + ".data")
+                    onnx.save_model(
+                        quant.model.model,
+                        output_path,
+                        save_as_external_data=True,
+                        all_tensors_to_one_file=True,
+                        location=os.path.basename(output_path) + ".data",
+                        size_threshold=0,
+                        convert_attribute=False,
+                    )
+            else:
+                logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")
+        else:
+            output_path = onnx_path
+
+        output_paths.append(output_path)
+
+    return output_paths
+
+
|
+
def main(argv=None):
|
|
497
|
+
warnings.warn(
|
|
498
|
+
"This example is deprecated. Use the Olive recipe instead: "
|
|
499
|
+
"https://github.com/microsoft/olive-recipes/tree/main",
|
|
500
|
+
DeprecationWarning,
|
|
501
|
+
stacklevel=2,
|
|
502
|
+
)
|
|
503
|
+
args = parse_arguments(argv)
|
|
504
|
+
|
|
505
|
+
setup_logger(args.verbose)
|
|
506
|
+
|
|
507
|
+
logger.info(f"Arguments:{args}")
|
|
508
|
+
|
|
509
|
+
cache_dir = args.cache_dir
|
|
510
|
+
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
|
|
511
|
+
prepare_environment(cache_dir, output_dir, args.use_gpu)
|
|
512
|
+
|
|
513
|
+
if args.precision == Precision.FLOAT16:
|
|
514
|
+
assert args.use_gpu, "fp16 requires --use_gpu"
|
|
515
|
+
|
|
516
|
+
output_paths = export_onnx_models(
|
|
517
|
+
args.model_name_or_path,
|
|
518
|
+
args.model_impl,
|
|
519
|
+
cache_dir,
|
|
520
|
+
output_dir,
|
|
521
|
+
args.use_gpu,
|
|
522
|
+
args.use_external_data_format,
|
|
523
|
+
args.optimize_onnx,
|
|
524
|
+
args.precision,
|
|
525
|
+
args.verbose,
|
|
526
|
+
args.use_forced_decoder_ids,
|
|
527
|
+
not args.separate_encoder_and_decoder_init,
|
|
528
|
+
args.no_beam_search_op,
|
|
529
|
+
args.use_decoder_masked_mha,
|
|
530
|
+
args.output_cross_qk,
|
|
531
|
+
args.overwrite,
|
|
532
|
+
not args.use_int64_inputs,
|
|
533
|
+
args.accuracy_level,
|
|
534
|
+
args.quantize_symmetric,
|
|
535
|
+
args.provider,
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
max_diff = 0
|
|
539
|
+
if not args.no_beam_search_op:
|
|
540
|
+
logger.info("Chaining model ... :")
|
|
541
|
+
args.beam_model_output_dir = WhisperHelper.get_onnx_path(
|
|
542
|
+
output_dir,
|
|
543
|
+
args.model_name_or_path,
|
|
544
|
+
suffix="_beamsearch",
|
|
545
|
+
new_folder=False,
|
|
546
|
+
)
|
|
547
|
+
for path in output_paths:
|
|
548
|
+
if "encoder_decoder" in path or "encoder" in path:
|
|
549
|
+
args.encoder_path = path
|
|
550
|
+
elif "decoder" in path:
|
|
551
|
+
args.decoder_path = path
|
|
552
|
+
chain_model(args)
|
|
553
|
+
output_paths.append(args.beam_model_output_dir)
|
|
554
|
+
|
|
555
|
+
# Check chained model
|
|
556
|
+
ort_session = create_onnxruntime_session(
|
|
557
|
+
args.beam_model_output_dir,
|
|
558
|
+
use_gpu=args.use_gpu,
|
|
559
|
+
provider=args.provider,
|
|
560
|
+
)
|
|
561
|
+
device = torch.device("cuda" if args.use_gpu else "cpu")
|
|
562
|
+
|
|
563
|
+
# Wrap parity check in try-except to allow export to continue in case this produces an error
|
|
564
|
+
try:
|
|
565
|
+
with torch.no_grad():
|
|
566
|
+
# Verify batched decoding with prompts for OpenAI implementation
|
|
567
|
+
if args.model_impl == "openai" and args.use_forced_decoder_ids:
|
|
568
|
+
max_diff = WhisperHelper.verify_onnx(
|
|
569
|
+
args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
|
|
570
|
+
)
|
|
571
|
+
else:
|
|
572
|
+
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
|
|
573
|
+
if max_diff > 1e-4:
|
|
574
|
+
logger.warning("PyTorch and ONNX Runtime results are NOT close")
|
|
575
|
+
else:
|
|
576
|
+
logger.info("PyTorch and ONNX Runtime results are close")
|
|
577
|
+
except Exception as e:
|
|
578
|
+
logger.warning(
|
|
579
|
+
f"An error occurred while trying to verify parity between PyTorch and ONNX Runtime: {e}", exc_info=True
|
|
580
|
+
)
|
|
581
|
+
|
|
582
|
+
# Remove extra ONNX models saved in output directory
|
|
583
|
+
for _file in os.listdir(output_dir):
|
|
584
|
+
if "_beamsearch" not in _file and "_jump_times" not in _file:
|
|
585
|
+
path = os.path.join(output_dir, _file)
|
|
586
|
+
os.remove(path)
|
|
587
|
+
if path in output_paths:
|
|
588
|
+
output_paths.remove(path)
|
|
589
|
+
|
|
590
|
+
else:
|
|
591
|
+
# Create ancillary JSON files for ONNX Runtime GenAI and/or Hugging Face's Optimum
|
|
592
|
+
WhisperHelper.save_processing(
|
|
593
|
+
args.model_name_or_path,
|
|
594
|
+
args.provider,
|
|
595
|
+
args.separate_encoder_and_decoder_init,
|
|
596
|
+
args.use_decoder_masked_mha,
|
|
597
|
+
args.output_cross_qk,
|
|
598
|
+
next(iter(filter(lambda path: "encoder" in path, output_paths))),
|
|
599
|
+
next(iter(filter(lambda path: "decoder" in path, output_paths))),
|
|
600
|
+
output_dir,
|
|
601
|
+
cache_dir,
|
|
602
|
+
)
|
|
603
|
+
|
|
604
|
+
logger.info(f"Done! Outputs: {output_paths}")
|
|
605
|
+
return max_diff
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
if __name__ == "__main__":
|
|
609
|
+
main()
|
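After a successful run without --no_beam_search_op, only the chained *_beamsearch.onnx model (and any *_jump_times artifacts) survive the cleanup above, and the chained model can be run directly with onnxruntime. A rough sketch under the assumption that the export used the default flags; the file name and input names are assumptions for illustration (flags such as --use_forced_decoder_ids or --use_temperature add inputs), so confirm with session.get_inputs() on the actual export:

    import numpy as np
    import onnxruntime

    session = onnxruntime.InferenceSession("whisper-tiny_beamsearch.onnx", providers=["CPUExecutionProvider"])
    inputs = {
        "input_features": np.zeros((1, 80, 3000), dtype=np.float32),  # 30s log-mel spectrogram; zeros as a placeholder
        "max_length": np.array([448], dtype=np.int32),
        "min_length": np.array([0], dtype=np.int32),
        "num_beams": np.array([4], dtype=np.int32),
        "num_return_sequences": np.array([1], dtype=np.int32),
        "length_penalty": np.array([1.0], dtype=np.float32),
        "repetition_penalty": np.array([1.0], dtype=np.float32),
    }
    outputs = session.run(None, inputs)
    sequences = outputs[0]  # token ids; decode with the matching Whisper tokenizer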