onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
# This script uses different configurations in mixed precision conversion for GPT-2 model, and
|
|
8
|
+
# measures the inference latency, top 1 match rate (compared to PyTorch FP32 model) and ONNX model size.
|
|
9
|
+
# It outputs a CSV file with Mann-Whitney U test and T-test results for each pair of experiments, where
|
|
10
|
+
# pvalue < 0.05 means two experiments have significant difference on top 1 match rate.
|
|
11
|
+
# Users can use this script to select the best mixed precision model according to these metrics.
|
|
12
|
+
|
|
13
|
+
import argparse
|
|
14
|
+
import csv
|
|
15
|
+
import datetime
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
|
|
20
|
+
import onnx
|
|
21
|
+
import scipy.stats
|
|
22
|
+
from benchmark_helper import get_ort_environment_variables, setup_logger
|
|
23
|
+
from convert_to_onnx import main
|
|
24
|
+
from gpt2_helper import PRETRAINED_GPT2_MODELS, Gpt2Helper
|
|
25
|
+
from onnx_model import OnnxModel
|
|
26
|
+
|
|
27
|
+
# Module-wide logger. logging.getLogger("") returns the root logger, so all messages from this
# script are handled by the root logger's handlers.
logger = logging.getLogger("")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def parse_arguments(argv=None):
    """Parse the command line of the GPT-2 mixed precision parity script.

    Args:
        argv: argument strings to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with model_name_or_path, csv, test_cases, runs and the boolean
        switches use_gpu, all, use_external_data_format, verbose, skip_test and overwrite.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-m",
        "--model_name_or_path",
        required=True,
        type=str,
        help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
    )
    parser.add_argument(
        "--csv",
        required=False,
        type=str,
        default="gpt2_parity_results.csv",
        help="path of csv file to save the result",
    )
    parser.add_argument(
        "--test_cases",
        required=False,
        type=int,
        default=500,
        help="number of test cases per run",
    )
    parser.add_argument("--runs", required=False, type=int, default=40, help="number of repeated runs")

    # Boolean switches, registered in the same order as the options they sit between in --help.
    # A help text of None is what argparse uses when no help is given.
    boolean_switches = (
        (("--use_gpu",), "use GPU for inference"),
        (("--all",), "run all combinations of mixed precision"),
        (("-e", "--use_external_data_format"), None),
        (("--verbose",), None),
        (("--skip_test",), "do not run test, and only rank experiments based on existing csv file"),
        (("--overwrite",), "Overwrite existing csv file"),
    )
    for option_strings, help_text in boolean_switches:
        parser.add_argument(*option_strings, required=False, action="store_true", help=help_text)

    # Every switch is off unless given on the command line.
    parser.set_defaults(
        use_gpu=False,
        all=False,
        use_external_data_format=False,
        verbose=False,
        skip_test=False,
        overwrite=False,
    )

    return parser.parse_args(argv)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class ParityTask:
    """Runs parity experiments through ``main`` and keeps the results they return.

    Every experiment receives the same number of test cases and runs, and is tagged with a
    run id built from the start timestamp and a per-task counter.
    """

    def __init__(self, test_cases, total_runs, csv_path):
        # Results returned by main(); falsy results and failed runs are not recorded.
        self.results = []
        # Counter appended to the timestamp to form each experiment's run id.
        self.run_id = 0
        self.csv_path = csv_path
        self.test_cases = test_cases
        self.total_runs = total_runs

    def run(self, argv, experiment_name):
        """Run one experiment.

        Args:
            argv: convert_to_onnx style arguments for this experiment; "-t <test_cases>" and
                "-r <total_runs>" are appended.
            experiment_name: label stored with the result.

        Returns:
            Whatever ``main`` returned, or ``None`` when the experiment raised (the exception
            is logged).
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        current_run_id = f"{timestamp}_{self.run_id}"
        self.run_id += 1

        try:
            command_line = list(argv) + ["-t", f"{self.test_cases}", "-r", f"{self.total_runs}"]
            outcome = main(
                command_line,
                experiment_name=experiment_name,
                run_id=current_run_id,
                csv_filename=self.csv_path,
            )
            if outcome:
                self.results.append(outcome)
        except Exception:
            # Best effort: one failing experiment must not abort the whole sweep.
            logger.exception(f"Failed to run experiment {experiment_name}")
            outcome = None

        return outcome
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def load_results_from_csv(csv_path):
    """Load every data row of a results CSV file.

    Args:
        csv_path: path of a CSV file whose first line is the header row.

    Returns:
        A list with one dict per data row, mapping header names to the raw (string) cell
        values, exactly as produced by ``csv.DictReader``.
    """
    # `csv` is imported at module level; DictReader already yields the rows, so list() suffices.
    with open(csv_path, newline="") as csvfile:
        return list(csv.DictReader(csvfile))
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_latency(row):
    """Return the average latency recorded in a result row, as a float.

    The first column whose name starts with ``average_latency(batch_size=`` is used (the
    column name presumably also embeds the benchmark settings).

    Raises:
        RuntimeError: if the row has no such column.
    """
    prefix = "average_latency(batch_size="
    latency_column = next((column for column in row if column.startswith(prefix)), None)
    if latency_column is None:
        raise RuntimeError("Failed to get average_latency from output")
    return float(row[latency_column])
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def score(row):
    """Fold latency, top-1 match rate and ONNX model size into one score; larger is better.

    Assuming top1_match_rate is a fraction in [0, 1], each of these is worth one point:
    0.1% of top-1 match rate (gained), 0.1 ms of latency (lost), 100 MB of model size (lost).
    """
    latency = get_latency(row)
    match_rate = float(row["top1_match_rate"])
    size_mb = float(row["onnx_size_in_MB"])

    accuracy_reward = match_rate * 1000
    latency_penalty = latency * 10
    size_penalty = size_mb / 100
    return accuracy_reward - latency_penalty - size_penalty
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def print_wins(wins, rows, test_name):
    """Log experiments ranked by significance-test wins, using score() as the tie-breaker.

    Args:
        wins: dict mapping run_id to the number of pairwise significance tests that run won.
        rows: result rows; each must provide run_id, top1_match_rate, onnx_size_in_MB,
            experiment and an ``average_latency(batch_size=...`` column.
        test_name: label of the significance test, used in the log messages.
    """
    print()
    print("*" * 10)

    # Index rows by run_id once. The same mapping drives both the sort key and the report
    # below, so both always refer to the same row (the last one if a run_id is duplicated).
    row_map = {row["run_id"]: row for row in rows}

    sorted_wins = dict(
        sorted(
            wins.items(),
            key=lambda item: (item[1], score(row_map[item[0]])),
            reverse=True,
        )
    )
    logger.debug(f"{test_name} Wins:{sorted_wins}")
    logger.info(f"Based on {test_name} wins and a scoring function, the ranking:")

    # Nothing in the loop below touches the environment, so query the ORT variables once.
    env_vars = get_ort_environment_variables()

    rank = 0
    previous_value = -1
    for count, (key, value) in enumerate(sorted_wins.items()):
        # Equal win counts share a rank: the 0-based position of the first entry with that count.
        if value != previous_value:
            rank = count
        previous_value = value

        row = row_map[key]
        logger.info(
            "{:02d}: WINs={:02d}, run_id={}, latency={:5.2f}, top1_match={:.4f}, size={}_MB, experiment={}, {}".format(  # noqa: G001
                rank,
                value,
                key,
                get_latency(row),
                float(row["top1_match_rate"]),
                row["onnx_size_in_MB"],
                row["experiment"],
                env_vars,
            )
        )
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _parse_match_rate_per_run(row):
    """Return the per-run top-1 match rates of a result row (JSON-decoded when stored as text)."""
    value = row["top1_match_rate_per_run"]
    return json.loads(value) if isinstance(value, str) else value


def _record_significant_win(wins, result1, result2):
    """Credit one win to the row with the higher aggregate top1_match_rate; no win on an exact tie."""
    rate1 = float(result1["top1_match_rate"])
    rate2 = float(result2["top1_match_rate"])
    if rate1 > rate2:
        wins[result1["run_id"]] += 1
    elif rate2 > rate1:
        wins[result2["run_id"]] += 1


def run_significance_test(rows, output_csv_path):
    """Run pairwise Mann-Whitney U tests and T-tests on per-run top-1 match rates.

    Every pair of rows that agree on model_name, test_cases and runs is compared. One CSV row of
    statistics per compared pair is written to ``output_csv_path``. Whenever a test's p-value is
    below 0.05, the row with the higher aggregate top1_match_rate is credited with a win for that
    test; the win tallies are then ranked and logged via print_wins.

    Args:
        rows: experiment result rows (e.g. from load_results_from_csv) with run_id, model_name,
            test_cases, runs, experiment, top1_match_rate and top1_match_rate_per_run.
        output_csv_path: path of the CSV file to (over)write with the pairwise statistics.
    """
    utest_wins = {row["run_id"]: 0 for row in rows}
    ttest_wins = {row["run_id"]: 0 for row in rows}

    column_names = [
        "model_name",
        "run_id_1",
        "experiment_1",
        "top1_match_rate_1",
        "run_id_2",
        "experiment_2",
        "top1_match_rate_2",
        "U_statistic",
        "U_pvalue",
        "T_statistic",
        "T_pvalue",
    ]
    # Only experiments measured under the same model and sample sizes are comparable.
    required_match_columns = ["model_name", "test_cases", "runs"]

    # Each row's per-run list is decoded at most once and reused for every pair it is part of.
    parsed_rates = {}

    def match_rates(index):
        if index not in parsed_rates:
            parsed_rates[index] = _parse_match_rate_per_run(rows[index])
        return parsed_rates[index]

    with open(output_csv_path, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=column_names)
        writer.writeheader()

        num_results = len(rows)
        for i in range(num_results - 1):
            result1 = rows[i]
            for j in range(i + 1, num_results):
                result2 = rows[j]
                if any(result1[column] != result2[column] for column in required_match_columns):
                    continue

                a = match_rates(i)
                b = match_rates(j)

                try:
                    utest_statistic, utest_pvalue = scipy.stats.mannwhitneyu(
                        a, b, use_continuity=True, alternative="two-sided"
                    )  # TODO: shall we use one-sided: less or greater according to "top1_match_rate"
                except ValueError:  # ValueError: All numbers are identical in mannwhitneyu
                    utest_statistic = None
                    utest_pvalue = None
                ttest_statistic, ttest_pvalue = scipy.stats.ttest_ind(a, b, axis=None, equal_var=True)

                if utest_pvalue is not None and utest_pvalue < 0.05:
                    _record_significant_win(utest_wins, result1, result2)
                # A NaN p-value (e.g. zero variance in both samples) compares False and awards no win.
                if ttest_pvalue < 0.05:
                    _record_significant_win(ttest_wins, result1, result2)

                writer.writerow(
                    {
                        "model_name": result1["model_name"],
                        "run_id_1": result1["run_id"],
                        "experiment_1": result1["experiment"],
                        "top1_match_rate_1": float(result1["top1_match_rate"]),
                        "run_id_2": result2["run_id"],
                        "experiment_2": result2["experiment"],
                        "top1_match_rate_2": float(result2["top1_match_rate"]),
                        "U_statistic": utest_statistic,
                        "U_pvalue": utest_pvalue,
                        "T_statistic": ttest_statistic,
                        "T_pvalue": ttest_pvalue,
                    }
                )

    logger.info(f"U-Test and T-Test results are output to {output_csv_path}")
    print_wins(utest_wins, rows, "U-Test")
    print_wins(ttest_wins, rows, "T-Test")
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def get_last_matmul_node_name(raw_onnx_model: str):
    """Find the MatMul node that produces the first graph output of an ONNX model.

    The first graph output is assumed to be the logits (TODO confirm against the exporter).

    Args:
        raw_onnx_model: path of the ONNX model file.

    Returns:
        The name of the producing node when it is a MatMul; otherwise ``None`` after logging a
        warning (also when the graph has no output or no node produces the first output).
    """
    model = onnx.load(raw_onnx_model)
    if len(model.graph.output) == 0:
        logger.warning(f"Failed to find MatMul node for logits: model {raw_onnx_model} has no graph output")
        return None

    output_name = model.graph.output[0].name
    onnx_model = OnnxModel(model)
    output_name_to_node = onnx_model.output_name_to_node()

    # Explicit check rather than `assert` (stripped under `python -O`); reported the same way as
    # the non-MatMul case below, so callers only ever have to handle None.
    if output_name not in output_name_to_node:
        logger.warning(f"Failed to find MatMul node for logits. No node produces graph output {output_name}")
        return None

    node = output_name_to_node[output_name]
    if node.op_type == "MatMul":
        logger.info(f"Found last MatMul node for logits: {node.name}")
        return node.name

    logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
    return None
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list):
    """Build convert_to_onnx arguments for an FP16 mixed precision experiment on GPU.

    The "logits" graph I/O, the node ``last_matmul_node_name`` and every operator type in
    ``op_block_list`` are kept in FP32.

    Args:
        args: parsed command line; model_name_or_path and use_external_data_format are read.
        last_matmul_node_name: node to keep in FP32, or None to omit --node_block_list.
        op_block_list: operator types to keep in FP32; may be empty.

    Returns:
        The argument list (list of str).
    """
    # Build the list explicitly instead of splitting a formatted string, so a model path that
    # contains spaces stays a single argument.
    parameters = ["-m", str(args.model_name_or_path), "-o", "--use_gpu", "-p", "fp16"]
    if args.use_external_data_format:
        parameters.append("--use_external_data_format")

    parameters += ["--io_block_list", "logits"]
    # get_last_matmul_node_name() returns None when no MatMul produces the logits; a None entry
    # would break argument parsing downstream.
    if last_matmul_node_name is not None:
        parameters += ["--node_block_list", last_matmul_node_name]

    if op_block_list:
        parameters.extend(["--op_block_list", *op_block_list])

    return parameters
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def run_candidate(
    task: ParityTask,
    args,
    last_matmul_node_name,
    op_block_list=("FastGelu", "LayerNormalization"),
):
    """Run one mixed precision candidate: the mixed precision baseline plus ``op_block_list`` in FP32.

    Args:
        task: parity task whose ``run(parameters, name)`` executes the experiment.
        args: parsed arguments, forwarded to get_mixed_precision_parameters.
        last_matmul_node_name: name of the last MatMul node (logits producer) kept in FP32.
        op_block_list: operator types kept in FP32. The default is an immutable tuple, which avoids
            the shared mutable-default pitfall; any sequence of strings (e.g. a list) is accepted.
    """
    parameters = get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list)

    if op_block_list:
        op_block_list_str = ",".join(sorted(op_block_list))
        name = f"Mixed precision baseline + {op_block_list_str} in FP32"
    else:
        name = f"Mixed precision baseline (logits output and last MatMul node {last_matmul_node_name} in FP32)"

    # Tag the run label with any ORT environment variables that affect the run.
    env_vars = get_ort_environment_variables()
    if env_vars:
        name = name + f" ({env_vars})"

    task.run(parameters, name)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def get_baselines(args):
    """Build the argument lists for the FP32 and FP16 baseline parity runs.

    Args:
        args: parsed arguments; ``model_name_or_path``, ``use_gpu`` and ``use_external_data_format``
            are read.

    Returns:
        A ``(fp32_baseline, fp16_baseline)`` tuple of argument-string lists.
    """
    # Lists are built directly (not by splitting an f-string) so that a model path containing
    # spaces stays a single argument.
    model = str(args.model_name_or_path)

    fp32_baseline = ["-m", model, "-o", "-p", "fp32"]
    if args.use_gpu:
        fp32_baseline.append("--use_gpu")
    if args.use_external_data_format:
        fp32_baseline.append("--use_external_data_format")

    # The FP16 baseline always requests the GPU, regardless of args.use_gpu.
    fp16_baseline = ["-m", model, "-o", "--use_gpu", "-p", "fp16"]
    if args.use_external_data_format:
        fp16_baseline.append("--use_external_data_format")

    return fp32_baseline, fp16_baseline
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def run_tuning_step0(task, fp16_baseline, all_ops, optimized_ops):
    """Step 0 is to check which operator in FP16 causes most loss.

    First three reference configurations are run:
    - FP16 except the logits output;
    - graph I/O in FP32, everything else FP16;
    - every operator in ``all_ops`` blocked to FP32, with only the initializers forced to FP16.

    Then, for each operator in ``optimized_ops``, a run is made in which only that operator is left
    in FP16. All other optimized operators are blocked to FP32 and graph I/O is kept FP32.

    Args:
        task: parity task whose ``run(parameters, name)`` returns a result mapping or a falsy value.
        fp16_baseline: argument list of the FP16 baseline run.
        all_ops: every operator type in the model.
        optimized_ops: operator types produced by graph optimization.

    Returns:
        The per-operator result with the lowest ``top1_match_rate``, or None when no per-operator
        run produced a result.
    """
    fp32_logits = ["--io_block_list", "logits"]
    task.run(fp16_baseline + fp32_logits, "FP16 except logits")

    fp32_io = ["--keep_io_types"]
    task.run(fp16_baseline + fp32_io, "Graph I/O FP32, Other FP16")

    # Only weights in FP16
    task.run(
        fp16_baseline + fp32_io + ["--op_block_list", *all_ops, "--force_fp16_initializers"],
        "FP32 except weights in FP16",
    )

    optimized_ops_results = []
    for op in optimized_ops:
        op_block_list = ["--op_block_list"] + [o for o in optimized_ops if o != op]
        result = task.run(fp16_baseline + fp32_io + op_block_list, f"FP32 except {op} in FP16")
        if result:
            optimized_ops_results.append(result)

    # min() on an empty list raises ValueError; this happens when every per-operator run failed
    # or there are no optimized operators.
    if not optimized_ops_results:
        print("step 0: no per-operator results to compare")
        return None

    # Check which optimized operator causes the most loss in precision
    min_result = min(optimized_ops_results, key=lambda y: y["top1_match_rate"])
    print("step 0: optimized operator causes the most loss in precision", min_result)
    return min_result
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def run_tuning_step1(task, mixed_precision_baseline, optimized_ops):
    """Step 1: find out which optimized operator benefits most from being kept in FP32.

    For each operator type in ``optimized_ops``, one run is made: the mixed precision baseline
    arguments with that single operator added to ``--op_block_list``. The values returned by
    ``task.run`` are not used here.
    """
    for candidate_op in optimized_ops:
        command = mixed_precision_baseline + ["--op_block_list", candidate_op]
        task.run(command, f"Mixed precision baseline + {candidate_op} in FP32")
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def run_tuning_step2(task, mixed_precision_baseline, optimized_ops):
    """Step 2: try keeping one more operator in FP32 on top of a fixed FP32 set.

    Assumes steps 0 and 1 have shown that the logits and some operators should stay in FP32.
    The fixed set consists of whichever of FastGelu, LayerNormalization and SkipLayerNormalization
    appear in ``optimized_ops``. Each remaining optimized operator is then added to that set in turn,
    and the mixed precision baseline is run with the combined ``--op_block_list``.
    """
    candidate_fp32_ops = ["FastGelu", "LayerNormalization", "SkipLayerNormalization"]
    fp32_ops = [x for x in candidate_fp32_ops if x in optimized_ops]
    for op in optimized_ops:
        if op in fp32_ops:
            continue
        op_block_list = [*fp32_ops, op]
        # Join the whole block list so the label has no leading comma when fp32_ops is empty.
        op_block_list_str = ",".join(op_block_list)
        task.run(
            [*mixed_precision_baseline, "--op_block_list", *op_block_list],
            f"Mixed precision baseline + {op_block_list_str} in FP32",
        )
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def run_parity(task: ParityTask, args):
    """Run the FP32 baseline and, when a GPU is requested, the FP16 and mixed precision experiments.

    The FP32 baseline result must report ``optimized_operators`` and ``operators`` as
    comma-separated strings. These decide which operator types the mixed precision candidates
    keep in FP32.

    With ``args.all``, tuning steps 0-2 are run. Otherwise a fixed list of candidate FP32 operator
    sets is evaluated.

    Raises:
        RuntimeError: if the FP32 baseline run does not report its operators.
    """
    onnx_model_paths = Gpt2Helper.get_onnx_paths(
        "onnx_models",
        args.model_name_or_path,
        new_folder=args.use_external_data_format,
        remove_existing=[],
    )

    fp32_baseline, fp16_baseline = get_baselines(args)

    fp32_result = task.run(fp32_baseline, "FP32 baseline")

    def split_reported_ops(field, error_message):
        # The baseline result reports operator types as one comma-separated string.
        if fp32_result and field in fp32_result and fp32_result[field]:
            return fp32_result[field].split(",")
        raise RuntimeError(error_message)

    optimized_ops = split_reported_ops("optimized_operators", "Failed to get optimized operators")
    all_ops = split_reported_ops("operators", "Failed to get operators")

    # Every run below uses FP16, which requires a GPU.
    if not args.use_gpu:
        logger.info("skip mixed precision since --use_gpu is not specified")
        return

    task.run(fp16_baseline, "FP16 baseline")

    last_matmul_node_name = get_last_matmul_node_name(onnx_model_paths["raw"])

    # Mixed precision baseline: only the logits output and the last MatMul node stay in FP32.
    run_candidate(task, args, last_matmul_node_name, op_block_list=[])

    if args.all:
        run_tuning_step0(task, fp16_baseline, all_ops, optimized_ops)
        mixed_precision_baseline = get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list=[])
        run_tuning_step1(task, mixed_precision_baseline, optimized_ops)
        run_tuning_step2(task, mixed_precision_baseline, optimized_ops)
        return

    def present_in_model(op_types):
        # Keep only operator types that actually occur in the model.
        return [op for op in op_types if op in all_ops]

    # A few good candidates, each kept in FP32 on top of the mixed precision baseline.
    candidate_block_lists = [
        present_in_model(["SkipLayerNormalization", "LayerNormalization", "Add"]),
        ["FastGelu"],
        present_in_model(["FastGelu", "SkipLayerNormalization", "LayerNormalization", "Add"]),
        present_in_model(
            ["FastGelu", "EmbedLayerNormalization", "SkipLayerNormalization", "LayerNormalization", "Add"]
        ),
    ]
    for block_list in candidate_block_lists:
        run_candidate(task, args, last_matmul_node_name, op_block_list=block_list)
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
if __name__ == "__main__":
    args = parse_arguments()
    setup_logger(args.verbose)

    if args.test_cases < 100 or args.runs < 20 or args.test_cases * args.runs < 10000:
        logger.warning(
            "Not enough test cases or runs to get stable results or test significance. "
            "Recommend test_cases >= 100, runs >= 20, test_cases * runs >= 10000."
        )

    # Refuse to clobber previous results unless --overwrite is given. With --skip_test the existing
    # file is only read, so it is left alone.
    if os.path.exists(args.csv) and not args.skip_test:
        if not args.overwrite:
            raise RuntimeError(
                f"Output file {args.csv} existed. Please remove the file, or use either --skip_test or --overwrite."
            )
        logger.info("Remove existing file %s since --overwrite is specified", args.csv)
        os.remove(args.csv)

    task = ParityTask(args.test_cases, args.runs, args.csv)

    if not args.skip_test:
        run_parity(task, args)

    try:
        rows = load_results_from_csv(task.csv_path)
    except Exception:
        # Fall back to the results kept in memory by the task.
        logger.exception("Failed to load csv %s", task.csv_path)
        rows = task.results

    logger.info("Start running significance tests...")
    # Derive the summary path from the file extension only. This way the summary path can never equal
    # the results path (even without a ".csv" suffix), and ".csv" elsewhere in the path is untouched.
    csv_root, csv_ext = os.path.splitext(task.csv_path)
    summary_csv = f"{csv_root}.stats{csv_ext or '.csv'}"
    run_significance_test(rows, summary_csv)
|