onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1064 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import shutil
|
|
12
|
+
import subprocess
|
|
13
|
+
import sys
|
|
14
|
+
import tempfile
|
|
15
|
+
import warnings
|
|
16
|
+
from itertools import chain
|
|
17
|
+
|
|
18
|
+
import onnx
|
|
19
|
+
import torch
|
|
20
|
+
from benchmark_helper import Precision, prepare_environment, setup_logger
|
|
21
|
+
from convert_generation import replace_mha_with_gqa
|
|
22
|
+
from dist_settings import barrier, get_rank, get_size, init_dist
|
|
23
|
+
from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs
|
|
24
|
+
from llama_parity import main as parity_check
|
|
25
|
+
from llama_torch import setup_torch_model
|
|
26
|
+
|
|
27
|
+
# to patch transformers before exporting for transformers >= 4.45
|
|
28
|
+
from models.torch_export_patches import bypass_export_some_errors
|
|
29
|
+
from models.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes
|
|
30
|
+
from onnx_model import OnnxModel
|
|
31
|
+
from optimizer import optimize_model
|
|
32
|
+
from packaging import version
|
|
33
|
+
from transformers import AutoConfig, AutoModelForCausalLM
|
|
34
|
+
|
|
35
|
+
from onnxruntime import __version__ as ort_version
|
|
36
|
+
from onnxruntime import quantization as ort_quantization
|
|
37
|
+
|
|
38
|
+
if version.parse(ort_version) < version.parse("1.22.0"):
|
|
39
|
+
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer as MatMulNBitsQuantizer
|
|
40
|
+
else:
|
|
41
|
+
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
|
|
42
|
+
|
|
43
|
+
torch_export_onnx_opset_version = 14
|
|
44
|
+
logger = logging.getLogger("")
|
|
45
|
+
init_dist()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_model_dynamic_axes(input_names: list[str], output_names: list[str]):
|
|
49
|
+
dynamic_axes = {}
|
|
50
|
+
for name in input_names + output_names:
|
|
51
|
+
if name in input_names:
|
|
52
|
+
# shape is (batch_size, sequence_length)
|
|
53
|
+
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
|
|
54
|
+
elif name == "logits":
|
|
55
|
+
# shape is (batch_size, sequence_length, vocab_size)
|
|
56
|
+
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
|
|
57
|
+
elif "present" in name:
|
|
58
|
+
# shape is (batch_size, num_heads, sequence_length, head_size)
|
|
59
|
+
dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"}
|
|
60
|
+
else:
|
|
61
|
+
raise Exception("Unknown input or output name found")
|
|
62
|
+
return dynamic_axes
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_model_with_past_kv_dynamic_axes(input_names: list[str], output_names: list[str]):
|
|
66
|
+
dynamic_axes = {}
|
|
67
|
+
for name in input_names + output_names:
|
|
68
|
+
if name in {"input_ids", "position_ids"}:
|
|
69
|
+
# shape is (batch_size, 1)
|
|
70
|
+
dynamic_axes[name] = {0: "batch_size"}
|
|
71
|
+
elif name == "attention_mask":
|
|
72
|
+
# shape is (batch_size, past_sequence_length + 1)
|
|
73
|
+
dynamic_axes[name] = {0: "batch_size", 1: "past_sequence_length + 1"}
|
|
74
|
+
elif "past" in name:
|
|
75
|
+
# shape is (batch_size, num_heads, past_sequence_length, head_size)
|
|
76
|
+
dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
|
|
77
|
+
elif name == "logits":
|
|
78
|
+
# shape is (batch_size, 1, vocab_size)
|
|
79
|
+
dynamic_axes[name] = {0: "batch_size"}
|
|
80
|
+
elif "present" in name:
|
|
81
|
+
# shape is (batch_size, num_heads, past_sequence_length + 1, head_size)
|
|
82
|
+
dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length + 1"}
|
|
83
|
+
else:
|
|
84
|
+
raise Exception("Unknown input or output name found")
|
|
85
|
+
return dynamic_axes
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def get_merged_model_dynamic_axes(input_names: list[str], output_names: list[str]):
|
|
89
|
+
dynamic_axes = {}
|
|
90
|
+
for name in input_names + output_names:
|
|
91
|
+
if name in {"input_ids", "position_ids"}:
|
|
92
|
+
# shape is (batch_size, sequence_length)
|
|
93
|
+
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
|
|
94
|
+
elif name == "attention_mask":
|
|
95
|
+
# shape is (batch_size, past_sequence_length + sequence_length) = (batch_size, total_sequence_length)
|
|
96
|
+
# for prompt generation, past_sequence_length = 0
|
|
97
|
+
# for token generation, sequence_length = 1
|
|
98
|
+
dynamic_axes[name] = {0: "batch_size", 1: "total_sequence_length"}
|
|
99
|
+
elif "past" in name:
|
|
100
|
+
# shape is (batch_size, num_heads, past_sequence_length, head_size)
|
|
101
|
+
dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
|
|
102
|
+
elif name == "logits":
|
|
103
|
+
# shape is (batch_size, sequence_length, vocab_size)
|
|
104
|
+
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
|
|
105
|
+
elif "present" in name:
|
|
106
|
+
# shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size) = (batch_size, num_heads, total_sequence_length, head_size)
|
|
107
|
+
# for prompt generation, past_sequence_length = 0
|
|
108
|
+
# for token generation, sequence_length = 1
|
|
109
|
+
dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"}
|
|
110
|
+
else:
|
|
111
|
+
raise Exception("Unknown input or output name found")
|
|
112
|
+
return dynamic_axes
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: str):
|
|
116
|
+
onnx.save(
|
|
117
|
+
onnx_model,
|
|
118
|
+
output_path,
|
|
119
|
+
save_as_external_data=True,
|
|
120
|
+
all_tensors_to_one_file=True,
|
|
121
|
+
location=data_path,
|
|
122
|
+
size_threshold=1024,
|
|
123
|
+
convert_attribute=False,
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def run_dynamo_export(
|
|
128
|
+
args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1
|
|
129
|
+
):
|
|
130
|
+
from torch._dynamo import config # noqa: PLC0415
|
|
131
|
+
|
|
132
|
+
config.capture_scalar_outputs = True
|
|
133
|
+
|
|
134
|
+
# Dummy values for export
|
|
135
|
+
batch_size, sequence_length, past_sequence_length = 2, 8, 3
|
|
136
|
+
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
|
|
137
|
+
|
|
138
|
+
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
|
|
139
|
+
max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
|
|
140
|
+
|
|
141
|
+
# Export decoder_with_past_model.onnx
|
|
142
|
+
input_ids, attn_mask, pos_ids, past_kv = get_merged_sample_with_past_kv_inputs(
|
|
143
|
+
l_config,
|
|
144
|
+
device,
|
|
145
|
+
batch_size,
|
|
146
|
+
sequence_length,
|
|
147
|
+
past_sequence_length,
|
|
148
|
+
max_seq_len=max_sequence_length,
|
|
149
|
+
use_fp16=False,
|
|
150
|
+
world_size=world_size,
|
|
151
|
+
)
|
|
152
|
+
temp_dir = tempfile.TemporaryDirectory()
|
|
153
|
+
temp_path = os.path.join(temp_dir.name, "temp.onnx")
|
|
154
|
+
|
|
155
|
+
input_names = ["input_ids", "attention_mask", "position_ids"]
|
|
156
|
+
output_names = [
|
|
157
|
+
"logits",
|
|
158
|
+
*list(
|
|
159
|
+
chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
|
|
160
|
+
),
|
|
161
|
+
]
|
|
162
|
+
dynamic_axes = get_model_dynamic_axes(input_names, output_names)
|
|
163
|
+
|
|
164
|
+
model_args = (input_ids, attn_mask, pos_ids, past_kv)
|
|
165
|
+
model_args, model_kwargs, dynamic_shapes = convert_dynamic_axes_into_dynamic_shapes(
|
|
166
|
+
llama, args=model_args, dynamic_axes=dynamic_axes, prefix_mapping={"present": "past_key_values"}
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
with bypass_export_some_errors(patch_transformers=True):
|
|
170
|
+
torch.onnx.export(
|
|
171
|
+
llama,
|
|
172
|
+
(),
|
|
173
|
+
temp_path,
|
|
174
|
+
kwargs=model_kwargs,
|
|
175
|
+
dynamic_shapes=dynamic_shapes,
|
|
176
|
+
dynamo=True,
|
|
177
|
+
verbose=args.verbose,
|
|
178
|
+
optimize=True,
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
# Check decoder_with_past_model.onnx and save all external data to one file
|
|
182
|
+
onnx.checker.check_model(temp_path)
|
|
183
|
+
onnx.shape_inference.infer_shapes_path(temp_path)
|
|
184
|
+
|
|
185
|
+
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx")
|
|
186
|
+
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
|
187
|
+
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data")
|
|
188
|
+
del onnx_model
|
|
189
|
+
temp_dir.cleanup()
|
|
190
|
+
|
|
191
|
+
logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _prepare_dir(dir_path):
|
|
195
|
+
if not os.path.exists(dir_path):
|
|
196
|
+
os.makedirs(dir_path)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def run_torchscript_separate_export(
|
|
200
|
+
args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1
|
|
201
|
+
):
|
|
202
|
+
# Dummy values for export
|
|
203
|
+
batch_size, sequence_length = 2, 8
|
|
204
|
+
|
|
205
|
+
# set device used to export model
|
|
206
|
+
# for llama-2-70b we will use current gpus to speed up export process
|
|
207
|
+
# for other models, we will use CPU to make sure we have enough memory to do export
|
|
208
|
+
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
|
|
209
|
+
|
|
210
|
+
# Export decoder_model.onnx
|
|
211
|
+
decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length)
|
|
212
|
+
|
|
213
|
+
input_names = ["input_ids", "attention_mask", "position_ids"]
|
|
214
|
+
output_names = [
|
|
215
|
+
"logits",
|
|
216
|
+
*list(
|
|
217
|
+
chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
|
|
218
|
+
),
|
|
219
|
+
]
|
|
220
|
+
dynamic_axes = get_model_dynamic_axes(input_names, output_names)
|
|
221
|
+
|
|
222
|
+
# Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large.
|
|
223
|
+
# Use temp folder per rank to avoid race condition here.
|
|
224
|
+
temp_dir = f"./temp_{rank}"
|
|
225
|
+
_prepare_dir(temp_dir)
|
|
226
|
+
temp_path = os.path.join(temp_dir, "temp.onnx")
|
|
227
|
+
torch.onnx.export(
|
|
228
|
+
llama,
|
|
229
|
+
args=decoder_inputs,
|
|
230
|
+
f=temp_path,
|
|
231
|
+
export_params=True,
|
|
232
|
+
input_names=input_names,
|
|
233
|
+
output_names=output_names,
|
|
234
|
+
dynamic_axes=dynamic_axes,
|
|
235
|
+
opset_version=torch_export_onnx_opset_version,
|
|
236
|
+
do_constant_folding=True,
|
|
237
|
+
verbose=args.verbose,
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
# Check decoder_model.onnx and save all external data to one file
|
|
241
|
+
onnx.checker.check_model(temp_path)
|
|
242
|
+
onnx.shape_inference.infer_shapes_path(temp_path)
|
|
243
|
+
|
|
244
|
+
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx")
|
|
245
|
+
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
|
246
|
+
save_onnx_model(
|
|
247
|
+
onnx_model,
|
|
248
|
+
output_path,
|
|
249
|
+
f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data",
|
|
250
|
+
)
|
|
251
|
+
del onnx_model
|
|
252
|
+
shutil.rmtree(temp_dir)
|
|
253
|
+
|
|
254
|
+
# Export decoder_with_past_model.onnx
|
|
255
|
+
decoder_with_past_inputs = get_sample_with_past_kv_inputs(
|
|
256
|
+
l_config,
|
|
257
|
+
device,
|
|
258
|
+
batch_size,
|
|
259
|
+
sequence_length,
|
|
260
|
+
use_fp16=False,
|
|
261
|
+
world_size=world_size,
|
|
262
|
+
)
|
|
263
|
+
input_names = [
|
|
264
|
+
"input_ids",
|
|
265
|
+
"attention_mask",
|
|
266
|
+
"position_ids",
|
|
267
|
+
*list(
|
|
268
|
+
chain.from_iterable(
|
|
269
|
+
(f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(l_config.num_hidden_layers)
|
|
270
|
+
)
|
|
271
|
+
),
|
|
272
|
+
]
|
|
273
|
+
output_names = [
|
|
274
|
+
"logits",
|
|
275
|
+
*list(
|
|
276
|
+
chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
|
|
277
|
+
),
|
|
278
|
+
]
|
|
279
|
+
dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names)
|
|
280
|
+
|
|
281
|
+
# Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large.
|
|
282
|
+
# Use temp folder per rank to avoid race condition here.
|
|
283
|
+
temp_dir = f"./temp_past_{rank}"
|
|
284
|
+
_prepare_dir(temp_dir)
|
|
285
|
+
temp_path = os.path.join(temp_dir, "temp.onnx")
|
|
286
|
+
torch.onnx.export(
|
|
287
|
+
llama,
|
|
288
|
+
args=decoder_with_past_inputs,
|
|
289
|
+
f=temp_path,
|
|
290
|
+
export_params=True,
|
|
291
|
+
input_names=input_names,
|
|
292
|
+
output_names=output_names,
|
|
293
|
+
dynamic_axes=dynamic_axes,
|
|
294
|
+
opset_version=torch_export_onnx_opset_version,
|
|
295
|
+
do_constant_folding=True,
|
|
296
|
+
verbose=args.verbose,
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
# Check decoder_with_past_model.onnx and save all external data to one file
|
|
300
|
+
onnx.checker.check_model(temp_path)
|
|
301
|
+
onnx.shape_inference.infer_shapes_path(temp_path)
|
|
302
|
+
|
|
303
|
+
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx")
|
|
304
|
+
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
|
305
|
+
save_onnx_model(
|
|
306
|
+
onnx_model,
|
|
307
|
+
output_path,
|
|
308
|
+
f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data",
|
|
309
|
+
)
|
|
310
|
+
del onnx_model
|
|
311
|
+
shutil.rmtree(temp_dir)
|
|
312
|
+
|
|
313
|
+
logger.info(
|
|
314
|
+
f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!"
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def run_torchscript_merged_export(
|
|
319
|
+
args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1
|
|
320
|
+
):
|
|
321
|
+
# Dummy values for export
|
|
322
|
+
batch_size, sequence_length, past_sequence_length = 2, 8, 0
|
|
323
|
+
|
|
324
|
+
# set device used to export model
|
|
325
|
+
# for llama-2-70b we will use current gpus to speed up export process
|
|
326
|
+
# for other models, we will use CPU to make sure we have enough memory to do export
|
|
327
|
+
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
|
|
328
|
+
|
|
329
|
+
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
|
|
330
|
+
max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
|
|
331
|
+
|
|
332
|
+
# Export decoder_merged_model.onnx
|
|
333
|
+
decoder_merged_inputs = get_merged_sample_with_past_kv_inputs(
|
|
334
|
+
l_config,
|
|
335
|
+
device,
|
|
336
|
+
batch_size,
|
|
337
|
+
sequence_length,
|
|
338
|
+
past_sequence_length,
|
|
339
|
+
max_seq_len=max_sequence_length,
|
|
340
|
+
use_fp16=False,
|
|
341
|
+
world_size=world_size,
|
|
342
|
+
)
|
|
343
|
+
input_names = [
|
|
344
|
+
"input_ids",
|
|
345
|
+
"attention_mask",
|
|
346
|
+
"position_ids",
|
|
347
|
+
*list(
|
|
348
|
+
chain.from_iterable(
|
|
349
|
+
(f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(l_config.num_hidden_layers)
|
|
350
|
+
)
|
|
351
|
+
),
|
|
352
|
+
]
|
|
353
|
+
output_names = [
|
|
354
|
+
"logits",
|
|
355
|
+
*list(
|
|
356
|
+
chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
|
|
357
|
+
),
|
|
358
|
+
]
|
|
359
|
+
dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names)
|
|
360
|
+
|
|
361
|
+
# Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large.
|
|
362
|
+
# Use temp folder per rank to avoid race condition here.
|
|
363
|
+
temp_dir = f"./temp_{rank}"
|
|
364
|
+
_prepare_dir(temp_dir)
|
|
365
|
+
temp_path = os.path.join(temp_dir, "temp.onnx")
|
|
366
|
+
|
|
367
|
+
torch.onnx.export(
|
|
368
|
+
llama,
|
|
369
|
+
args=decoder_merged_inputs,
|
|
370
|
+
f=temp_path,
|
|
371
|
+
export_params=True,
|
|
372
|
+
input_names=input_names,
|
|
373
|
+
output_names=output_names,
|
|
374
|
+
dynamic_axes=dynamic_axes,
|
|
375
|
+
opset_version=torch_export_onnx_opset_version,
|
|
376
|
+
do_constant_folding=True,
|
|
377
|
+
verbose=args.verbose,
|
|
378
|
+
dynamo=False,
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
# Check decoder_merged_model.onnx and save all external data to one file
|
|
382
|
+
onnx.checker.check_model(temp_path)
|
|
383
|
+
onnx.shape_inference.infer_shapes_path(temp_path)
|
|
384
|
+
|
|
385
|
+
output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx")
|
|
386
|
+
onnx_model = onnx.load_model(temp_path, load_external_data=True)
|
|
387
|
+
save_onnx_model(
|
|
388
|
+
onnx_model,
|
|
389
|
+
output_path,
|
|
390
|
+
f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx.data",
|
|
391
|
+
)
|
|
392
|
+
del onnx_model
|
|
393
|
+
shutil.rmtree(temp_dir)
|
|
394
|
+
|
|
395
|
+
logger.info(f"The {args.model_name} merged ONNX model has been successfully created with the TorchScript exporter!")
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
# Optimize the model as FP32
|
|
399
|
+
def optimize_export(
|
|
400
|
+
args: argparse.Namespace,
|
|
401
|
+
config: AutoConfig,
|
|
402
|
+
input_path: str,
|
|
403
|
+
output_path: str,
|
|
404
|
+
remove_model: bool = True,
|
|
405
|
+
world_size: int = 1,
|
|
406
|
+
window_size: int = -1,
|
|
407
|
+
):
|
|
408
|
+
from fusion_options import FusionOptions # noqa: PLC0415
|
|
409
|
+
|
|
410
|
+
optimization_options = FusionOptions("gpt2")
|
|
411
|
+
|
|
412
|
+
model_opt = optimize_model(
|
|
413
|
+
input_path,
|
|
414
|
+
model_type="gpt2",
|
|
415
|
+
num_heads=config.num_attention_heads,
|
|
416
|
+
hidden_size=config.hidden_size,
|
|
417
|
+
opt_level=0,
|
|
418
|
+
optimization_options=optimization_options,
|
|
419
|
+
only_onnxruntime=False,
|
|
420
|
+
)
|
|
421
|
+
if args.use_gqa:
|
|
422
|
+
model_opt = use_group_query_attention(config, model_opt, world_size, window_size)
|
|
423
|
+
model_opt.save_model_to_file(output_path, use_external_data_format=True)
|
|
424
|
+
|
|
425
|
+
# Run symbolic shape inference on optimized model to avoid shape errors during runtime
|
|
426
|
+
# Ex: Before attention fusion, RotaryEmbedding assumes a 4D input and produces a 4D output.
|
|
427
|
+
# After attention fusion, RotaryEmbedding expects a 3D input and produces a 3D output.
|
|
428
|
+
wheel_cmd = [sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer"]
|
|
429
|
+
source_cmd = [sys.executable, "../symbolic_shape_infer.py"]
|
|
430
|
+
symbolic_shape_infer_args = [
|
|
431
|
+
"--input",
|
|
432
|
+
output_path,
|
|
433
|
+
"--output",
|
|
434
|
+
output_path,
|
|
435
|
+
"--auto_merge",
|
|
436
|
+
"--save_as_external_data",
|
|
437
|
+
"--all_tensors_to_one_file",
|
|
438
|
+
"--external_data_location",
|
|
439
|
+
os.path.basename(output_path) + ".data",
|
|
440
|
+
]
|
|
441
|
+
|
|
442
|
+
file_path = os.path.dirname(__file__)
|
|
443
|
+
if os.path.exists(os.path.join(file_path, "../../../tools/symbolic_shape_infer.py")):
|
|
444
|
+
main_cmd = wheel_cmd
|
|
445
|
+
else:
|
|
446
|
+
main_cmd = source_cmd
|
|
447
|
+
subprocess.run(main_cmd + symbolic_shape_infer_args) # noqa: PLW1510
|
|
448
|
+
|
|
449
|
+
logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
|
|
450
|
+
if remove_model:
|
|
451
|
+
remove_existing_model(input_path)
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def convert_to_float16(args: argparse.Namespace, old_paths: list[str], rank: int = 0):
|
|
455
|
+
decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx")
|
|
456
|
+
decoder_with_past_model_fp16_path = os.path.join(
|
|
457
|
+
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp16.onnx"
|
|
458
|
+
)
|
|
459
|
+
decoder_merged_model_fp16_path = os.path.join(
|
|
460
|
+
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx"
|
|
461
|
+
)
|
|
462
|
+
new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path]
|
|
463
|
+
|
|
464
|
+
logger.info("Converting to float16...")
|
|
465
|
+
for fp32_path, fp16_path in zip(old_paths, new_paths, strict=False):
|
|
466
|
+
if os.path.exists(fp32_path):
|
|
467
|
+
model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True))
|
|
468
|
+
model.convert_float_to_float16(keep_io_types=False)
|
|
469
|
+
model.save_model_to_file(fp16_path, use_external_data_format=True)
|
|
470
|
+
del model
|
|
471
|
+
logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!")
|
|
472
|
+
remove_existing_model(fp32_path)
|
|
473
|
+
|
|
474
|
+
logger.info(f"The {args.model_name} ONNX model has been successfully converted to float16!")
|
|
475
|
+
return new_paths
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def use_group_query_attention(config: AutoConfig, model_opt: OnnxModel, world_size: int = 1, window_size: int = -1):
|
|
479
|
+
# Replace MultiHeadAttention with GroupQueryAttention
|
|
480
|
+
model_opt = replace_mha_with_gqa(model_opt, "attention_mask", config.num_key_value_heads, world_size, window_size)
|
|
481
|
+
model_opt.prune_graph()
|
|
482
|
+
model_opt.update_graph(allow_remove_graph_inputs=True)
|
|
483
|
+
return model_opt
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def smooth_quant(
|
|
487
|
+
args: argparse.Namespace,
|
|
488
|
+
decoder_model_fp32_path: str,
|
|
489
|
+
decoder_with_past_model_fp32_path: str,
|
|
490
|
+
decoder_model_int8_path: str,
|
|
491
|
+
decoder_with_past_model_int8_path: str,
|
|
492
|
+
):
|
|
493
|
+
from neural_compressor import PostTrainingQuantConfig, set_workspace # noqa: PLC0415
|
|
494
|
+
from neural_compressor import quantization as intel_quantization # noqa: PLC0415
|
|
495
|
+
from onnx.external_data_helper import load_external_data_for_model # noqa: PLC0415
|
|
496
|
+
from quant_kv_dataloader import QuantKVDataLoader # noqa: PLC0415
|
|
497
|
+
|
|
498
|
+
set_workspace(args.nc_workspace)
|
|
499
|
+
quantization_config = PostTrainingQuantConfig(
|
|
500
|
+
calibration_sampling_size=[args.calibration_sampling_size],
|
|
501
|
+
recipes={
|
|
502
|
+
"optypes_to_exclude_output_quant": ["MatMul"],
|
|
503
|
+
"smooth_quant": True,
|
|
504
|
+
"smooth_quant_args": {"alpha": args.smooth_quant_alpha},
|
|
505
|
+
},
|
|
506
|
+
op_type_dict={
|
|
507
|
+
"^((?!(MatMul|Gather|Conv)).)*$": {
|
|
508
|
+
"weight": {"dtype": ["fp32"]},
|
|
509
|
+
"activation": {"dtype": ["fp32"]},
|
|
510
|
+
}
|
|
511
|
+
},
|
|
512
|
+
)
|
|
513
|
+
|
|
514
|
+
# Convert decoder_model.onnx to INT8
|
|
515
|
+
decoder_model_int8 = intel_quantization.fit(
|
|
516
|
+
decoder_model_fp32_path,
|
|
517
|
+
quantization_config,
|
|
518
|
+
calib_dataloader=QuantKVDataLoader(args),
|
|
519
|
+
)
|
|
520
|
+
load_external_data_for_model(
|
|
521
|
+
decoder_model_int8._model,
|
|
522
|
+
os.path.split(decoder_model_int8._model_path)[0],
|
|
523
|
+
)
|
|
524
|
+
save_onnx_model(
|
|
525
|
+
decoder_model_int8._model,
|
|
526
|
+
decoder_model_int8_path,
|
|
527
|
+
f"{args.model_name}_decoder_model_int8.onnx.data",
|
|
528
|
+
)
|
|
529
|
+
del decoder_model_int8
|
|
530
|
+
logger.info(
|
|
531
|
+
f"The ONNX model at {decoder_model_fp32_path} has been quantized to int8 and saved at {decoder_model_int8_path}!"
|
|
532
|
+
)
|
|
533
|
+
remove_existing_model(decoder_model_fp32_path)
|
|
534
|
+
|
|
535
|
+
# Convert decoder_with_past_model.onnx to INT8
|
|
536
|
+
decoder_with_past_model_int8 = intel_quantization.fit(
|
|
537
|
+
decoder_with_past_model_fp32_path,
|
|
538
|
+
quantization_config,
|
|
539
|
+
calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path),
|
|
540
|
+
)
|
|
541
|
+
load_external_data_for_model(
|
|
542
|
+
decoder_with_past_model_int8._model,
|
|
543
|
+
os.path.split(decoder_with_past_model_int8._model_path)[0],
|
|
544
|
+
)
|
|
545
|
+
save_onnx_model(
|
|
546
|
+
decoder_with_past_model_int8._model,
|
|
547
|
+
decoder_with_past_model_int8_path,
|
|
548
|
+
f"{args.model_name}_decoder_with_past_model_int8.onnx.data",
|
|
549
|
+
)
|
|
550
|
+
del decoder_with_past_model_int8
|
|
551
|
+
logger.info(
|
|
552
|
+
f"The ONNX model at {decoder_with_past_model_fp32_path} has been quantized to int8 and saved at {decoder_with_past_model_int8_path}!"
|
|
553
|
+
)
|
|
554
|
+
remove_existing_model(decoder_with_past_model_fp32_path)
|
|
555
|
+
|
|
556
|
+
logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
|
|
557
|
+
|
|
558
|
+
logger.warning(f"Removing {args.nc_workspace}")
|
|
559
|
+
shutil.rmtree(args.nc_workspace)
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def remove_existing_model(model_path: str):
|
|
563
|
+
# Remove ONNX model and its external data
|
|
564
|
+
data_path = os.path.join(model_path + ".data")
|
|
565
|
+
os.remove(model_path)
|
|
566
|
+
os.remove(data_path)
|
|
567
|
+
logger.warning(f"Removed {model_path} and {data_path}")
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def remove_existing_files(output_path: str):
|
|
571
|
+
for filename in os.listdir(output_path):
|
|
572
|
+
filepath = os.path.join(output_path, filename)
|
|
573
|
+
if ".onnx" in filename or ".onnx.data" in filename:
|
|
574
|
+
os.remove(filepath)
|
|
575
|
+
logger.warning(f"Removed {filepath}")
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
def optimize_optimum(config: AutoConfig, args: argparse.Namespace):
|
|
579
|
+
tmp_file = os.path.join(args.output, args.model_name + ".tmp.onnx")
|
|
580
|
+
output_file = os.path.join(args.output, args.model_name + ".onnx")
|
|
581
|
+
window_size = -1 if not hasattr(config, "sliding_window") else config.sliding_window
|
|
582
|
+
optimize_export(args, config, args.input, tmp_file, remove_model=False, window_size=window_size)
|
|
583
|
+
logger.info(f"Model successfully optimized to {tmp_file}")
|
|
584
|
+
opt_model = OnnxModel(onnx.load_model(tmp_file, load_external_data=True))
|
|
585
|
+
if args.precision == Precision.FLOAT16:
|
|
586
|
+
opt_model.convert_float_to_float16(keep_io_types=False)
|
|
587
|
+
logger.info("Model successfully fused and quantized to FP16!")
|
|
588
|
+
opt_model.save_model_to_file(output_file, use_external_data_format=True)
|
|
589
|
+
logger.info(f"Output model successfully saved to {output_file}")
|
|
590
|
+
logger.info(f"Removing {tmp_file}")
|
|
591
|
+
remove_existing_model(tmp_file)
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
def get_args():
|
|
595
|
+
parser = argparse.ArgumentParser()
|
|
596
|
+
|
|
597
|
+
parser.add_argument(
|
|
598
|
+
"-m",
|
|
599
|
+
"--model_name",
|
|
600
|
+
required=True,
|
|
601
|
+
help="Model name in Hugging Face",
|
|
602
|
+
)
|
|
603
|
+
|
|
604
|
+
parser.add_argument(
|
|
605
|
+
"-i",
|
|
606
|
+
"--input",
|
|
607
|
+
required=False,
|
|
608
|
+
default=os.path.join("."),
|
|
609
|
+
help="Directory path to PyTorch model and associated files if saved on disk, or ONNX model file location if optimize_optimum is passed.",
|
|
610
|
+
)
|
|
611
|
+
|
|
612
|
+
parser.add_argument(
|
|
613
|
+
"-o",
|
|
614
|
+
"--output",
|
|
615
|
+
required=False,
|
|
616
|
+
default=os.path.join(".", "llama_onnx_models"),
|
|
617
|
+
help="Directory path to save exported model files in",
|
|
618
|
+
)
|
|
619
|
+
|
|
620
|
+
parser.add_argument(
|
|
621
|
+
"-p",
|
|
622
|
+
"--precision",
|
|
623
|
+
required=False,
|
|
624
|
+
type=Precision,
|
|
625
|
+
default=Precision.FLOAT32,
|
|
626
|
+
choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4],
|
|
627
|
+
help="Precision to export model in",
|
|
628
|
+
)
|
|
629
|
+
|
|
630
|
+
parser.add_argument(
|
|
631
|
+
"-e",
|
|
632
|
+
"--execution_provider",
|
|
633
|
+
required=False,
|
|
634
|
+
default="cpu",
|
|
635
|
+
choices=["cpu", "cuda"],
|
|
636
|
+
help="Execution provider to verify parity with",
|
|
637
|
+
)
|
|
638
|
+
|
|
639
|
+
parser.add_argument(
|
|
640
|
+
"-r",
|
|
641
|
+
"--reexport",
|
|
642
|
+
required=False,
|
|
643
|
+
action="store_true",
|
|
644
|
+
help="Re-export models and overwrite existing models in output folder",
|
|
645
|
+
)
|
|
646
|
+
parser.set_defaults(reexport=False)
|
|
647
|
+
|
|
648
|
+
parser.add_argument(
|
|
649
|
+
"--use_gqa",
|
|
650
|
+
required=False,
|
|
651
|
+
action="store_true",
|
|
652
|
+
help="Use GroupQueryAttention instead of MultiHeadAttention",
|
|
653
|
+
)
|
|
654
|
+
parser.set_defaults(use_gqa=False)
|
|
655
|
+
|
|
656
|
+
parser.add_argument(
|
|
657
|
+
"--no_merged",
|
|
658
|
+
required=False,
|
|
659
|
+
action="store_true",
|
|
660
|
+
help="Export models into 2 ONNX files instead of 1. Deprecated in favor of exporting into 1 ONNX file.",
|
|
661
|
+
)
|
|
662
|
+
parser.set_defaults(no_merged=False)
|
|
663
|
+
|
|
664
|
+
parser.add_argument(
|
|
665
|
+
"-q",
|
|
666
|
+
"--quantization_method",
|
|
667
|
+
default="",
|
|
668
|
+
choices=["blockwise", "smooth_quant", "quantize_dynamic"],
|
|
669
|
+
help="Run a specific quantization algorithm (blockwise for int4, smooth_quant for int8, quantize_dynamic for int8). Blockwise is recommended. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.",
|
|
670
|
+
)
|
|
671
|
+
|
|
672
|
+
blockwise_group = parser.add_argument_group("blockwise (4-bit quantization)")
|
|
673
|
+
|
|
674
|
+
parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight")
|
|
675
|
+
|
|
676
|
+
blockwise_group.add_argument(
|
|
677
|
+
"--block_size",
|
|
678
|
+
required=False,
|
|
679
|
+
default=32,
|
|
680
|
+
type=int,
|
|
681
|
+
help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py for details.",
|
|
682
|
+
)
|
|
683
|
+
|
|
684
|
+
blockwise_group.add_argument(
|
|
685
|
+
"--int4_accuracy_level",
|
|
686
|
+
required=False,
|
|
687
|
+
type=int,
|
|
688
|
+
help="Accuracy level of the 4-bit quantized MatMul computation. "
|
|
689
|
+
"Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
|
|
690
|
+
"(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
|
|
691
|
+
)
|
|
692
|
+
|
|
693
|
+
smooth_quant_group = parser.add_argument_group("smooth_quant (8-bit quantization)")
|
|
694
|
+
|
|
695
|
+
smooth_quant_group.add_argument(
|
|
696
|
+
"--smooth_quant_alpha",
|
|
697
|
+
required=False,
|
|
698
|
+
default=0.8,
|
|
699
|
+
type=float,
|
|
700
|
+
help="Strength to control migration difficulty from activation to weights. Default is 0.8 to match value \
|
|
701
|
+
used in original paper for LLaMA. Paper recommends using values in [0.4, 0.6] range. \
|
|
702
|
+
Link to paper: https://arxiv.org/pdf/2211.10438.pdf",
|
|
703
|
+
)
|
|
704
|
+
|
|
705
|
+
smooth_quant_group.add_argument(
|
|
706
|
+
"--smooth_quant_dataset",
|
|
707
|
+
required=False,
|
|
708
|
+
default="NeelNanda/pile-10k",
|
|
709
|
+
help="Path to dataset for calibration during quantization",
|
|
710
|
+
)
|
|
711
|
+
|
|
712
|
+
smooth_quant_group.add_argument(
|
|
713
|
+
"--pad_max",
|
|
714
|
+
required=False,
|
|
715
|
+
default=196,
|
|
716
|
+
type=int,
|
|
717
|
+
help="Max padding size",
|
|
718
|
+
)
|
|
719
|
+
|
|
720
|
+
smooth_quant_group.add_argument(
|
|
721
|
+
"--calibration_sampling_size",
|
|
722
|
+
required=False,
|
|
723
|
+
type=int,
|
|
724
|
+
default=8,
|
|
725
|
+
help="Calibration sampling size for quantization config",
|
|
726
|
+
)
|
|
727
|
+
|
|
728
|
+
smooth_quant_group.add_argument(
|
|
729
|
+
"--nc_workspace",
|
|
730
|
+
required=False,
|
|
731
|
+
type=str,
|
|
732
|
+
default=os.path.join(".", "nc_workspace"),
|
|
733
|
+
help="Workspace to save intermediate files generated by Intel's Neural Compressor package.",
|
|
734
|
+
)
|
|
735
|
+
|
|
736
|
+
quantize_dynamic_group = parser.add_argument_group("quantize_dynamic (8-bit quantization)")
|
|
737
|
+
|
|
738
|
+
quantize_dynamic_group.add_argument(
|
|
739
|
+
"--quantize_embedding_layer",
|
|
740
|
+
required=False,
|
|
741
|
+
action="store_true",
|
|
742
|
+
help="Quantize MatMul, GEMM, and Gather.",
|
|
743
|
+
)
|
|
744
|
+
quantize_dynamic_group.set_defaults(quantize_embedding_layer=False)
|
|
745
|
+
|
|
746
|
+
quantize_dynamic_group.add_argument(
|
|
747
|
+
"--quantize_per_channel",
|
|
748
|
+
required=False,
|
|
749
|
+
action="store_true",
|
|
750
|
+
help="Quantize weights per each channel.",
|
|
751
|
+
)
|
|
752
|
+
quantize_dynamic_group.set_defaults(quantize_per_channel=False)
|
|
753
|
+
|
|
754
|
+
quantize_dynamic_group.add_argument(
|
|
755
|
+
"--quantize_reduce_range",
|
|
756
|
+
required=False,
|
|
757
|
+
action="store_true",
|
|
758
|
+
help="Quantize weights with 7 bits.",
|
|
759
|
+
)
|
|
760
|
+
quantize_dynamic_group.set_defaults(quantize_reduce_range=False)
|
|
761
|
+
|
|
762
|
+
parser.add_argument(
|
|
763
|
+
"-v",
|
|
764
|
+
"--verbose",
|
|
765
|
+
action="store_true",
|
|
766
|
+
help="Print verbose logs",
|
|
767
|
+
)
|
|
768
|
+
parser.set_defaults(verbose=False)
|
|
769
|
+
|
|
770
|
+
parser.add_argument(
|
|
771
|
+
"-d",
|
|
772
|
+
"--use_dynamo_export",
|
|
773
|
+
action="store_true",
|
|
774
|
+
help="Use the new Dynamo exporter instead of the old TorchScript exporter",
|
|
775
|
+
)
|
|
776
|
+
parser.set_defaults(use_dynamo_export=False)
|
|
777
|
+
|
|
778
|
+
parser.add_argument(
|
|
779
|
+
"--cache_dir",
|
|
780
|
+
required=False,
|
|
781
|
+
type=str,
|
|
782
|
+
default="./model_cache",
|
|
783
|
+
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
|
|
784
|
+
)
|
|
785
|
+
|
|
786
|
+
parser.add_argument(
|
|
787
|
+
"--optimize_optimum",
|
|
788
|
+
action="store_true",
|
|
789
|
+
help="Avoid exporting model, only apply quantizations and optimizations to existing model exported from optimum.",
|
|
790
|
+
)
|
|
791
|
+
|
|
792
|
+
parser.add_argument(
|
|
793
|
+
"--small_gpu",
|
|
794
|
+
action="store_true",
|
|
795
|
+
help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB.",
|
|
796
|
+
)
|
|
797
|
+
|
|
798
|
+
parser.set_defaults(optimize_optimum=False)
|
|
799
|
+
|
|
800
|
+
args = parser.parse_args()
|
|
801
|
+
return args
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+

def main():
    warnings.warn(
        "This example is deprecated. Use the Olive recipe instead: "
        "https://github.com/microsoft/olive-recipes/tree/main",
        DeprecationWarning,
        stacklevel=2,
    )
    if version.parse(torch.__version__) < version.parse("2.2.0"):
        logger.error(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.")
        return

    args = get_args()
    setup_logger(args.verbose)
    prepare_environment(args.input, args.output, args.execution_provider != "cpu")
    if args.reexport:
        remove_existing_files(args.output)
    logger.info(f"Arguments: {args}")

    world_size = get_size()
    rank = get_rank()
    args.world_size = world_size

    # Load model and config
    use_auth_token = args.input == os.path.join(".")
    setattr(args, "use_auth_token", use_auth_token)  # noqa: B010

    original_model_name = args.model_name
    setattr(args, "original_model_name", original_model_name)  # noqa: B010
    args.model_name = args.model_name.split("/")[-1]

    setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}")  # noqa: B010
    setattr(args, "device", torch.device(args.device_name))  # noqa: B010

    location = args.original_model_name if use_auth_token else args.input

    if args.optimize_optimum:
        config = AutoConfig.from_pretrained(args.original_model_name, cache_dir=args.cache_dir)
        optimize_optimum(config, args)
        return

    # Use CUDA for LLaMA-2-70B to speed up export and CPU for other models
    l_config, llama = setup_torch_model(
        args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None
    )

    assert l_config.num_attention_heads % world_size == 0 and l_config.num_key_value_heads % world_size == 0

    barrier()
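    # Each rank takes its turn through the loop below; the barrier() at the end of each
    # iteration holds the other ranks back so only one rank exports and quantizes at a time.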
    for i in range(world_size):
        if i == rank:
            # Set model paths for FP32 model
            decoder_model_fp32_path = os.path.join(
                args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx"
            )
            decoder_with_past_model_fp32_path = os.path.join(
                args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx"
            )
            decoder_merged_model_fp32_path = os.path.join(
                args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx"
            )
            old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]

            missing_separate_exports = (
                args.no_merged
                and not os.path.exists(decoder_model_fp32_path)
                and not os.path.exists(decoder_with_past_model_fp32_path)
            )
            missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path)

            # Export to ONNX
            if missing_separate_exports or missing_merged_export:
                if args.use_dynamo_export:
                    logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
                    logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
                    logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
                    logger.warning(
                        "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script"
                    )
                    logger.warning(
                        "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step."
                    )
                    run_dynamo_export(args, l_config, llama)
                elif args.no_merged:
                    run_torchscript_separate_export(args, l_config, llama, rank, world_size)
                else:
                    run_torchscript_merged_export(args, l_config, llama, rank, world_size)
            del llama  # Delete LLaMA model from memory since it will be loaded again during parity check

            # Set model paths to store FP32 optimized model
            decoder_model_fp32_opt_path = os.path.join(
                args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx"
            )
            decoder_with_past_model_fp32_opt_path = os.path.join(
                args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32_opt.onnx"
            )
            decoder_merged_model_fp32_opt_path = os.path.join(
                args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32_opt.onnx"
            )
            new_paths = [
                decoder_model_fp32_opt_path,
                decoder_with_past_model_fp32_opt_path,
                decoder_merged_model_fp32_opt_path,
            ]

            # Run the optimizer script.
            logger.info("Optimizing models...")
            for orig_path, opt_path in zip(old_paths, new_paths, strict=False):
                if os.path.exists(orig_path):
                    optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size)

            # Re-assign default FP32 model paths as their optimized versions
            decoder_model_fp32_path = decoder_model_fp32_opt_path
            decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
            decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
            old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]

            logger.info(
                f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
            )

            # Change precision of exported models from FP32
            if args.precision == Precision.FLOAT16:
                new_paths = convert_to_float16(args, old_paths, rank)

            elif args.precision == Precision.INT8:
                decoder_model_int8_path = os.path.join(
                    args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx"
                )
                decoder_with_past_model_int8_path = os.path.join(
                    args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx"
                )
                decoder_merged_model_int8_path = os.path.join(
                    args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx"
                )
                new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]

                if args.quantization_method == "smooth_quant":
                    if not args.no_merged:
                        logger.error("SmoothQuant must be used on separately exported models")
                    else:
                        logger.info(
                            f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8"
                        )
                        smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])

                elif args.quantization_method == "quantize_dynamic":
                    logger.warning(
                        "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
                    )

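                    # Dynamically quantize each exported model; MatMulConstBOnly restricts
                    # quantization to MatMul nodes whose weight (B) input is a constant initializer.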
logger.info("Quantizing to int8...")
|
|
955
|
+
for fp32_path, int8_path in zip(old_paths, new_paths, strict=False):
|
|
956
|
+
if os.path.exists(fp32_path):
|
|
957
|
+
ort_quantization.quantize_dynamic(
|
|
958
|
+
fp32_path,
|
|
959
|
+
int8_path,
|
|
960
|
+
op_types_to_quantize=(
|
|
961
|
+
["MatMul", "Gemm", "Gather"]
|
|
962
|
+
if args.quantize_embedding_layer
|
|
963
|
+
else ["MatMul", "Gemm"]
|
|
964
|
+
),
|
|
965
|
+
per_channel=args.quantize_per_channel,
|
|
966
|
+
reduce_range=args.quantize_reduce_range,
|
|
967
|
+
use_external_data_format=True,
|
|
968
|
+
extra_options={"MatMulConstBOnly": True},
|
|
969
|
+
)
|
|
970
|
+
logger.info(
|
|
971
|
+
f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!"
|
|
972
|
+
)
|
|
973
|
+
remove_existing_model(decoder_model_fp32_path)
|
|
974
|
+
|
|
975
|
+
logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
|
|
976
|
+
|
|
977
|
+
else:
|
|
978
|
+
raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
|
|
979
|
+
|
|
980
|
+
elif args.precision == Precision.INT4:
|
|
981
|
+
if args.execution_provider != "cpu":
|
|
982
|
+
old_paths = convert_to_float16(args, old_paths, rank)
|
|
983
|
+
|
|
984
|
+
decoder_model_int4_path = os.path.join(
|
|
985
|
+
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx"
|
|
986
|
+
)
|
|
987
|
+
decoder_with_past_model_int4_path = os.path.join(
|
|
988
|
+
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx"
|
|
989
|
+
)
|
|
990
|
+
decoder_merged_model_int4_path = os.path.join(
|
|
991
|
+
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx"
|
|
992
|
+
)
|
|
993
|
+
new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]
|
|
994
|
+
|
|
995
|
+
for fp_path, int4_path in zip(old_paths, new_paths, strict=False):
|
|
996
|
+
if os.path.exists(fp_path):
|
|
997
|
+
model = onnx.load_model(fp_path, load_external_data=True)
|
|
998
|
+
quant = MatMulNBitsQuantizer(
|
|
999
|
+
model=model,
|
|
1000
|
+
bits=args.bits,
|
|
1001
|
+
block_size=args.block_size,
|
|
1002
|
+
is_symmetric=True,
|
|
1003
|
+
accuracy_level=args.int4_accuracy_level,
|
|
1004
|
+
nodes_to_exclude=[],
|
|
1005
|
+
)
|
|
1006
|
+
quant.process()
|
|
1007
|
+
quant.model.save_model_to_file(int4_path, use_external_data_format=True)
|
|
1008
|
+
del model
|
|
1009
|
+
del quant
|
|
1010
|
+
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
|
|
1011
|
+
remove_existing_model(fp_path)
|
|
1012
|
+
barrier()
|
|
1013
|
+
|
|
1014
|
+
logger.info("Verifying parity on all ONNX models created")
|
|
1015
|
+
|
|
1016
|
+
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
|
|
1017
|
+
args.precision = (
|
|
1018
|
+
"fp32"
|
|
1019
|
+
if args.precision in {Precision.INT8, Precision.FLOAT32}
|
|
1020
|
+
or (args.precision == Precision.INT4 and args.execution_provider == "cpu")
|
|
1021
|
+
else "fp16"
|
|
1022
|
+
)
|
|
1023
|
+
|
|
1024
|
+
# Verify parity on all saved ONNX models
|
|
1025
|
+
for filename in os.listdir(args.output):
|
|
1026
|
+
if (
|
|
1027
|
+
".data" in filename
|
|
1028
|
+
or ".onnx" not in filename
|
|
1029
|
+
or args.precision not in filename
|
|
1030
|
+
or f"rank_{rank}" not in filename
|
|
1031
|
+
):
|
|
1032
|
+
continue
|
|
1033
|
+
|
|
1034
|
+
parity_cmd = [
|
|
1035
|
+
"-m",
|
|
1036
|
+
original_model_name,
|
|
1037
|
+
"-o",
|
|
1038
|
+
os.path.join(args.output, filename),
|
|
1039
|
+
"-ep",
|
|
1040
|
+
args.execution_provider,
|
|
1041
|
+
"--precision",
|
|
1042
|
+
args.precision,
|
|
1043
|
+
"--cache_dir",
|
|
1044
|
+
args.cache_dir,
|
|
1045
|
+
"--torch_model_directory",
|
|
1046
|
+
args.input,
|
|
1047
|
+
]
|
|
1048
|
+
if args.small_gpu:
|
|
1049
|
+
parity_cmd.append("--small_gpu")
|
|
1050
|
+
if "with_past" in filename:
|
|
1051
|
+
parity_cmd.append("--use_past_kv")
|
|
1052
|
+
if "merged" in filename:
|
|
1053
|
+
parity_cmd.append("--merged")
|
|
1054
|
+
|
|
1055
|
+
try:
|
|
1056
|
+
logger.info(f"check parity with cmd: {parity_cmd}")
|
|
1057
|
+
parity_check(parity_cmd)
|
|
1058
|
+
except Exception as e:
|
|
1059
|
+
logger.exception(f"An error occurred while verifying parity: {e}")
|
|
1060
|
+
sys.exit(-1)
|
|
1061
|
+
|
|
1062
|
+
|
|
1063
|
+
if __name__ == "__main__":
|
|
1064
|
+
main()
|
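For the INT4 branch, the per-model loop above is a thin wrapper around `MatMulNBitsQuantizer` from this package's quantization tools. A minimal standalone sketch of the same calls, assuming a hypothetical `model_fp16.onnx` already exists on disk:

```python
import onnx
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer

# Hypothetical paths; the script above derives per-rank paths from args.output.
model = onnx.load_model("model_fp16.onnx", load_external_data=True)

quant = MatMulNBitsQuantizer(
    model=model,
    bits=4,               # the script passes args.bits
    block_size=32,        # the script passes args.block_size
    is_symmetric=True,
    accuracy_level=None,  # the script passes args.int4_accuracy_level
    nodes_to_exclude=[],
)
quant.process()
quant.model.save_model_to_file("model_int4.onnx", use_external_data_format=True)
```

Saving through `quant.model` with `use_external_data_format=True` mirrors the script and keeps the weights in external data files, which matters for multi-gigabyte LLaMA checkpoints.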