onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/llama/llama_parity.py
@@ -0,0 +1,343 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+from __future__ import annotations
+
+import argparse
+import logging
+import os
+import time
+
+import numpy as np
+import packaging.version as pv
+import torch
+from benchmark_helper import setup_logger
+from dist_settings import get_rank, get_size
+from llama_inputs import (
+    add_io_bindings_as_ortvalues,
+    convert_inputs_for_ort,
+    get_merged_sample_with_past_kv_inputs,
+    get_sample_inputs,
+    get_sample_with_past_kv_inputs,
+    verify_ort_inputs,
+)
+from llama_torch import setup_torch_model
+from models.torch_export_patches.cache_helper import make_dynamic_cache
+from transformers import AutoConfig
+from transformers import __version__ as transformers_version
+from transformers.cache_utils import DynamicCache
+
+import onnxruntime as ort
+
+logger = logging.getLogger("")
+
+
+def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig):
+    past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
+    max_sequence_length = config.max_position_embeddings
+    return past_sequence_length, curr_sequence_length, max_sequence_length
+
+
+def get_inputs(args: argparse.Namespace, config: AutoConfig):
+    # Dummy values for parity
+    world_size = get_size()
+    batch_size = 2
+    past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config)
+
+    if args.merged:
+        inputs = get_merged_sample_with_past_kv_inputs(
+            config,
+            args.device,
+            batch_size,
+            seq_len=sequence_length,
+            past_seq_len=past_sequence_length,
+            max_seq_len=max_sequence_length,
+            use_fp16=args.use_fp16,
+            use_buffer_share=args.use_buffer_share,
+            return_dict=True,
+            world_size=world_size,
+        )
+    elif args.use_past_kv:
+        inputs = get_sample_with_past_kv_inputs(
+            config,
+            args.device,
+            batch_size,
+            sequence_length,
+            use_fp16=args.use_fp16,
+            return_dict=True,
+            world_size=world_size,
+        )
+    else:
+        inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True)
+
+    return inputs
+
+
+def torch_deepcopy(value):
+    if isinstance(value, (int, float, str)):
+        return value
+    if isinstance(value, tuple):
+        return tuple(torch_deepcopy(v) for v in value)
+    if isinstance(value, list):
+        return [torch_deepcopy(v) for v in value]
+    if isinstance(value, set):
+        return {torch_deepcopy(v) for v in value}
+    if isinstance(value, dict):
+        return {k: torch_deepcopy(v) for k, v in value.items()}
+    if isinstance(value, np.ndarray):
+        return value.copy()
+    if hasattr(value, "clone"):
+        return value.clone()
+    if isinstance(value, DynamicCache):
+        return make_dynamic_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache, strict=False))))
+    # We should have a code using serialization, deserialization assuming a model
+    # cannot be exported without them.
+    raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
+
+
+def verify_parity(
+    args: argparse.Namespace,
+    location: str,
+    use_auth_token: bool,
+    kv_cache_ortvalues: dict,
+    pytorch_model: None | torch.nn.Module = None,
+    config: None | AutoConfig = None,
+):
+    # If it's running in a machine where GPU memory < 36GB, it should unload the model in GPU in time and free the GPU memory for ORT.
+    py_model = pytorch_model
+    if py_model is None:
+        config, py_model = setup_torch_model(
+            args,
+            location,
+            use_auth_token,
+            torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
+            device=args.device,
+        )
+
+    inputs = get_inputs(args, config)
+
+    if "past_key_values" in inputs and pv.Version(transformers_version) >= pv.Version("4.45"):
+        # Using DynamicCache
+        inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"])
+
+    # Run inference with PyTorch
+    inputs_after_deepcopy = torch_deepcopy(inputs)
+    if args.execution_provider != "cpu":
+        torch.cuda.synchronize()
+    start_time = time.time()
+    # If there is a cache in the inputs, we need to make a copy as the model modifies them inplace.
+    # DynamicCache inherits from torch.nn.Module in some version of transformers.
+    # We need to make the copy manually.
+    pt_outputs = py_model(**inputs_after_deepcopy).logits.detach().cpu().numpy()
+    if args.execution_provider != "cpu":
+        torch.cuda.synchronize()
+    end_time = time.time()
+    logger.info(f"PyTorch took {end_time - start_time} s")
+
+    if args.small_gpu and py_model is not None:
+        del py_model
+        torch.cuda.empty_cache()
+
+    # Run inference with ORT
+    past_sequence_length, _, max_sequence_length = get_sequence_lengths(args, config)
+    inputs = convert_inputs_for_ort(
+        inputs,
+        use_buffer_share=args.use_buffer_share,
+        past_seq_len=past_sequence_length,
+        max_seq_len=max_sequence_length,
+    )
+
+    ep = f"{args.execution_provider.upper()}ExecutionProvider"
+    if ep == "CUDAExecutionProvider":
+        ep = (ep, {"device_id": args.rank})
+    ort_model = ort.InferenceSession(
+        args.onnx_model_path,
+        sess_options=ort.SessionOptions(),
+        providers=[ep],
+    )
+    inputs = verify_ort_inputs(ort_model, inputs)
+
+    # Add IO bindings for non-CPU execution providers
+    if args.execution_provider != "cpu":
+        io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
+            ort_model,
+            ort_inputs=inputs,
+            device=args.execution_provider,
+            device_id=int(args.rank),
+            use_buffer_share=args.use_buffer_share,
+            kv_cache_ortvalues=kv_cache_ortvalues,
+        )
+
+        io_binding.synchronize_inputs()
+        start_time = time.time()
+        ort_model.run_with_iobinding(io_binding)
+        io_binding.synchronize_outputs()
+        end_time = time.time()
+
+        ort_outputs = io_binding.copy_outputs_to_cpu()[0]  # Get logits
+        del ort_model
+
+    else:
+        start_time = time.time()
+        ort_outputs = ort_model.run(None, inputs)
+        end_time = time.time()
+
+        ort_outputs = ort_outputs[0]  # Get logits
+
+    logger.info(f"ONNX Runtime took {end_time - start_time} s")
+
+    # Compare PyTorch and ONNX Runtime accuracy
+    tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1
+    parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
+    logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
+    if not parity:
+        logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
+    return kv_cache_ortvalues
+
+
+def get_args(argv: list[str]):
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-m",
+        "--model_name",
+        required=False,
+        help="Model name in Hugging Face",
+    )
+
+    parser.add_argument(
+        "-t",
+        "--torch_model_directory",
+        required=False,
+        default=os.path.join("."),
+        help="Path to folder containing PyTorch model and associated files if saved on disk",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--onnx_model_path",
+        required=True,
+        default=os.path.join("."),
+        help="Path to ONNX model (with external data files saved in the same folder as the model)",
+    )
+
+    parser.add_argument(
+        "-ep",
+        "--execution_provider",
+        required=False,
+        default="cpu",
+        choices=["cpu", "cuda"],
+        help="Execution provider to verify parity with",
+    )
+
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="store_true",
+        help="Print verbose logs",
+    )
+    parser.set_defaults(verbose=False)
+
+    parser.add_argument(
+        "-p",
+        "--use_past_kv",
+        action="store_true",
+        help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.",
+    )
+    parser.set_defaults(use_past_kv=False)
+
+    parser.add_argument(
+        "-g",
+        "--use_buffer_share",
+        action="store_true",
+        help="Use if model has GroupQueryAttention and you want to enable past-present buffer sharing",
+    )
+    parser.set_defaults(use_buffer_share=False)
+
+    parser.add_argument(
+        "--merged",
+        action="store_true",
+        help="Use merged model (i.e. decoder_merged_model.onnx).",
+    )
+    parser.set_defaults(merged=False)
+
+    parser.add_argument(
+        "-fp",
+        "--precision",
+        required=True,
+        choices=["int4", "int8", "fp16", "fp32"],
+        help="Precision of model",
+    )
+
+    parser.add_argument(
+        "--cache_dir",
+        required=False,
+        type=str,
+        default="./model_cache",
+        help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
+    )
+
+    # The argument is used for CI mainly, because the CI machine has 24G GPU memory at most.
+    parser.add_argument(
+        "--small_gpu",
+        action="store_true",
+        help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB. ",
+    )
+
+    args = parser.parse_args() if argv == [] else parser.parse_args(argv)
+
+    # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
+    args.precision = (
+        "fp32"
+        if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu")
+        else "fp16"
+    )
+    return args
+
+
+def main(argv: list[str] = []):  # noqa: B006
+    args = get_args(argv)
+    setup_logger(args.verbose)
+    logger.info(f"Arguments: {args}")
+    rank = get_rank()
+
+    # Load model and config
+    setattr(args, "use_fp16", args.precision == "fp16")  # noqa: B010
+    args.rank = rank
+    setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}")  # noqa: B010
+    setattr(args, "device", torch.device(args.device_name))  # noqa: B010
+    use_auth_token = args.torch_model_directory == os.path.join(".")
+    location = args.model_name if use_auth_token else args.torch_model_directory
+
+    kv_cache_ortvalues = {}
+    if not args.merged:
+        verify_parity(args, location, use_auth_token, kv_cache_ortvalues)
+    else:
+        config = llama = None
+        if not args.small_gpu:
+            config, llama = setup_torch_model(
+                args,
+                location,
+                use_auth_token,
+                torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
+                device=args.device,
+            )

+        # Verify prompt processing in merged model (decoder_model.onnx)
+        args.use_past_kv = False
+        kv_cache_ortvalues = verify_parity(
+            args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config
+        )
+
+        # Verify token generation in merged model (decoder_with_past_model.onnx)
+        args.use_past_kv = True
+        verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config)
+
+
+if __name__ == "__main__":
+    seed = 2
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    main()
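For context, a minimal sketch of driving the parity script above. `main()` accepts an argv list, so this mirrors a command-line run; the model name and ONNX path are placeholders rather than values from this package, and the import assumes the `models/llama` folder is on `sys.path`:

    # Hypothetical invocation of the parity check; the flags are those defined in get_args().
    from llama_parity import main

    main([
        "-m", "meta-llama/Llama-2-7b-hf",   # placeholder Hugging Face model name
        "-o", "llama2-7b-fp16/model.onnx",  # placeholder path to the exported ONNX model
        "-ep", "cuda",                      # or "cpu" (the default)
        "-fp", "fp16",
        "--merged",                         # verify both prompt processing and token generation
    ])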
onnxruntime/transformers/models/llama/llama_torch.py
@@ -0,0 +1,47 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import logging
+import os
+
+import torch
+from dist_settings import barrier, get_rank, get_size
+from transformers import AutoConfig, AutoModelForCausalLM
+
+logger = logging.getLogger("")
+
+
+def setup_torch_model(args, location, auth, torch_dtype=torch.float32, device=None):
+    world_size = get_size()
+    logger.info(f"world_size: {world_size}")
+    rank = get_rank()
+    barrier()
+
+    if not os.path.exists(args.cache_dir):
+        os.makedirs(args.cache_dir, exist_ok=True)
+
+    for i in range(world_size):
+        if i == rank % (world_size):
+            l_config = AutoConfig.from_pretrained(
+                location, use_auth_token=auth, cache_dir=args.cache_dir, trust_remote_code=auth
+            )
+            l_config.use_cache = True
+            l_config._attn_implementation = "eager"  # "eager" uses LlamaAttention for attention layer
+            llama = AutoModelForCausalLM.from_pretrained(
+                location,
+                use_auth_token=auth,
+                trust_remote_code=auth,
+                config=l_config,
+                torch_dtype=torch_dtype,
+                cache_dir=args.cache_dir,
+            )
+            if world_size > 1:
+                llama.parallel_model()
+            if device:
+                llama.to(device)
+            llama.eval()
+            llama.requires_grad_(False)
+        barrier()
+    return l_config, llama
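As a quick illustration, a hedged sketch of calling `setup_torch_model` directly. The `args` namespace only needs the `cache_dir` attribute read above; the model location is a placeholder:

    import argparse

    import torch
    from llama_torch import setup_torch_model  # assumes models/llama is on sys.path

    args = argparse.Namespace(cache_dir="./model_cache")
    config, llama = setup_torch_model(
        args,
        location="meta-llama/Llama-2-7b-hf",  # placeholder model name or local folder
        auth=True,  # forwarded as both use_auth_token and trust_remote_code
        torch_dtype=torch.float16,
        device=torch.device("cuda:0"),
    )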
onnxruntime/transformers/models/llama/quant_kv_dataloader.py
@@ -0,0 +1,108 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import argparse
+
+import numpy as np
+import torch
+from benchmark_helper import create_onnxruntime_session
+from datasets import load_dataset
+from llama_inputs import get_position_ids
+from torch.nn.functional import pad
+from torch.utils.data import DataLoader
+from transformers import LlamaTokenizer
+
+
+class QuantKVDataLoader:
+    def __init__(self, args: argparse.Namespace, onnx_model_path: str = ""):
+        self.batch_size = 1
+        self.pad_max = args.pad_max
+
+        tokenizer = LlamaTokenizer.from_pretrained(args.original_model_name, use_auth_token=args.use_auth_token)
+        dataset = load_dataset(args.smooth_quant_dataset, split="train")
+        dataset = dataset.map(lambda examples: tokenizer(examples["text"]), batched=True)
+        dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
+
+        self.dataloader = DataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            collate_fn=self.collate_batch,
+        )
+        self.decoder_model = (
+            create_onnxruntime_session(
+                onnx_model_path,
+                args.execution_provider != "cpu",  # use_gpu
+                provider=args.execution_provider,
+                verbose=args.verbose,
+            )
+            if onnx_model_path
+            else None
+        )
+
+    def collate_batch(self, batch):
+        input_ids_batched = []
+        attention_mask_batched = []
+        position_ids_batched = []
+        labels = []
+
+        for text in batch:
+            # Set inputs for model
+            input_ids = text["input_ids"]
+            attention_mask = torch.ones(len(input_ids))
+            position_ids = get_position_ids(attention_mask, use_past_kv=False)
+            label = len(input_ids) - 1
+
+            # Pad input data because all model inputs must have same shape
+            pad_len = self.pad_max - input_ids.shape[0]
+            input_ids = pad(input_ids, (0, pad_len), value=1)
+            attention_mask = pad(attention_mask, (0, pad_len), value=0)
+            position_ids = pad(position_ids, (0, pad_len), value=0)
+
+            input_ids_batched.append(input_ids)
+            attention_mask_batched.append(attention_mask)
+            position_ids_batched.append(position_ids)
+            labels.append(label)
+
+        input_ids_batched = torch.vstack(input_ids_batched)
+        attention_mask_batched = torch.vstack(attention_mask_batched)
+        position_ids_batched = torch.vstack(position_ids_batched)
+        labels = torch.tensor(labels)
+
+        return (input_ids_batched, attention_mask_batched, position_ids_batched), labels
+
+    def __iter__(self):
+        try:
+            for (input_ids, attention_mask, position_ids), labels in self.dataloader:
+                # Inputs for decoder_model.onnx
+                inputs = {
+                    "input_ids": input_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
+                    "attention_mask": attention_mask[:, :-1].detach().cpu().numpy().astype(np.int64),
+                    "position_ids": position_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
+                }
+                label = labels.detach().cpu().numpy()
+
+                if self.decoder_model is not None:
+                    # Run decoder_model.onnx to get inputs for decoder_with_past_model.onnx
+                    outputs = self.decoder_model.run(None, inputs)
+
+                    for i in range(int((len(outputs) - 1) / 2)):
+                        inputs[f"past_key_values.{i}.key"] = outputs[i * 2 + 1]
+                        inputs[f"past_key_values.{i}.value"] = outputs[i * 2 + 2]
+                    past_sequence_length = inputs["past_key_values.0.key"].shape[2]
+
+                    inputs["input_ids"] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype(np.int64)
+                    attn_mask_torch = torch.ones((self.batch_size, past_sequence_length + 1), dtype=torch.int64)
+                    inputs["attention_mask"] = attn_mask_torch.detach().cpu().numpy().astype(np.int64)
+                    inputs["position_ids"] = (
+                        get_position_ids(attn_mask_torch, use_past_kv=True).detach().cpu().numpy().astype(np.int64)
+                    )
+
+                # Yield (inputs, label) tuple for Intel's Neural Compressor:
+                # https://github.com/intel/neural-compressor/blob/d4baed9ea11614e1f0dc8a1f4f55b73ed3ed585c/neural_compressor/quantization.py#L55-L62
+                yield (inputs, label)
+
+        except StopIteration:
+            return
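For reference, a sketch of iterating `QuantKVDataLoader`. The `argparse.Namespace` fields mirror those read in `__init__` above; the tokenizer source, dataset name, and padding length are placeholders:

    import argparse

    from quant_kv_dataloader import QuantKVDataLoader  # assumes models/llama is on sys.path

    args = argparse.Namespace(
        original_model_name="meta-llama/Llama-2-7b-hf",  # placeholder tokenizer source
        use_auth_token=False,
        smooth_quant_dataset="NeelNanda/pile-10k",       # placeholder calibration dataset
        pad_max=196,                                     # placeholder max padded length
        execution_provider="cpu",
        verbose=False,
    )
    loader = QuantKVDataLoader(args, onnx_model_path="decoder_model.onnx")
    for inputs, label in loader:
        break  # each item is the (inputs, label) pair yielded in __iter__ above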
onnxruntime/transformers/models/longformer/__init__.py
@@ -0,0 +1,12 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import os.path
+import sys
+
+sys.path.append(os.path.dirname(__file__))
+
+transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if transformers_dir not in sys.path:
+    sys.path.append(transformers_dir)