onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/onnxruntime/transformers/large_model_exporter.py
@@ -0,0 +1,395 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+"""
+Export LLM to onnx
+"""
+import argparse
+import inspect
+import math
+import os
+import tempfile
+from pathlib import Path
+from typing import Optional
+
+import onnx
+import torch
+import transformers
+from torch import nn
+
+
+def disable_huggingface_init():
+    """do not init model twice as it slow initialization"""
+
+    torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x
+    torch.nn.init.uniform_ = lambda x, *args, **kwargs: x
+    torch.nn.init.normal_ = lambda x, *args, **kwargs: x
+    torch.nn.init.constant_ = lambda x, *args, **kwargs: x
+    torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x
+    torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x
+    torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x
+    torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x
+
+
+def get_model_parameter_size(model: nn.Module):
+    """to calculate how much memory this model needs"""
+    param_size = 0
+    param_sum = 0
+    for param in model.parameters():
+        param_size += param.nelement() * param.element_size()
+        param_sum += param.nelement()
+    buffer_size = 0
+    buffer_sum = 0
+    for buffer in model.buffers():
+        buffer_size += buffer.nelement() * buffer.element_size()
+        buffer_sum += buffer.nelement()
+    all_size = (param_size + buffer_size) / 1024 / 1024
+    return all_size
+
+
+def initialize_model_and_sample_inputs(hf_model: str, cache_dir: Optional[str], tokenizer=None):
+    """
+    get the pretrained torch model from hugginface,
+    and sample model-inputs
+    """
+
+    disable_huggingface_init()
+
+    model = transformers.AutoModelForCausalLM.from_pretrained(  # type: ignore
+        hf_model, torch_dtype=torch.float16, cache_dir=cache_dir, trust_remote_code=True
+    )
+    if tokenizer is None:
+        tokenizer = hf_model
+    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)  # type: ignore
+
+    sample_inputs = tuple(tokenizer("Hello, my dog is cute", return_tensors="pt").values())
+    return model, sample_inputs
+
+
+def auto_pipeline_parallel(model: nn.Module, gpulist: list, sample_inputs: tuple):
+    """Make the model executable across multiple GPUs."""
+
+    def input_gpu_device_hook(mod, inputs, kwargs):
+        modifyed_inputs = []
+        first_dev = None
+        for layer_input in inputs:
+            if type(layer_input) is not torch.Tensor:
+                modifyed_inputs.append(layer_input)
+            elif hasattr(mod, "weight"):
+                modifyed_inputs.append(layer_input.to(mod.weight.device))
+            elif hasattr(mod, "parameters"):
+                device = next(mod.parameters(), layer_input).device
+                modifyed_inputs.append(layer_input.to(device))
+            elif hasattr(next(mod.children(), None), "weight"):
+                modifyed_inputs.append(layer_input.to(next(mod.children()).weight.device))
+            elif first_dev is not None and layer_input.device != first_dev:
+                modifyed_inputs.append(layer_input.to(first_dev))
+            else:
+                modifyed_inputs.append(layer_input)
+            if first_dev is None:
+                first_dev = modifyed_inputs[0].device
+        for key, value in kwargs.items():
+            if type(value) is torch.Tensor:
+                kwargs[key] = value.to(first_dev)
+
+        return (tuple(modifyed_inputs), kwargs)
+
+    def move_layer_to_device_rurc(mod, dev):
+        mod.to(dev)
+        for layer in mod.named_children():
+            move_layer_to_device_rurc(layer[1], dev)
+
+    model = model.half()
+    all_hooks = []
+    all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
+    pre_fix = next(iter(model.named_children()))[0]
+    for top_name, top_module in model.named_children():
+        for name, module in top_module.named_children():
+            all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
+            if type(module) in [torch.nn.ModuleList]:
+                num_layers_on_each_gpu = math.floor(len(module) / len(gpulist))
+                for idx, attn_layer in enumerate(module):
+                    all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
+
+                    to_dev = gpulist[min(idx // num_layers_on_each_gpu, len(gpulist))]
+                    attn_layer.to(to_dev)
+                    move_layer_to_device_rurc(attn_layer, to_dev)
+                    print(f"move {pre_fix}.{name}.{idx} to {to_dev}")
+            else:
+                module.to(gpulist[0])
+                print(f"move {pre_fix}.{name} to {gpulist[0]}")
+        if len(list(top_module.named_children())) == 0:
+            top_module.to(gpulist[0])
+            print(f"move {top_name} to {gpulist[0]}")
+
+    with torch.no_grad():
+        model(sample_inputs[0], attention_mask=sample_inputs[1])
+    return model
+
+
+def retrieve_onnx_inputs(model: nn.Module, sample_inputs: tuple, with_past: bool):
+    """
+    auto retrieve onnx inputs from torch model as we can't enumlate all possibilities
+    for all models
+    """
+    user_inputs = []
+
+    def hook_for_inputs(_, inputs, kwargs):
+        user_inputs.append((inputs, kwargs))
+        return user_inputs[0]
+
+    hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True)
+
+    forward_params = inspect.signature(model.forward).parameters
+    input_keys = list(forward_params.keys())
+    default_values = [forward_params.get(key).default for key in input_keys]
+    out = model(sample_inputs[0], attention_mask=sample_inputs[1])
+    hook_handle.remove()
+    user_inputs = user_inputs[0]
+    onnx_inputs = default_values
+    for idx, _val in enumerate(user_inputs[0]):
+        onnx_inputs[idx] = user_inputs[0][idx]
+    for key, value in user_inputs[1].items():
+        idx = input_keys.index(key)
+        onnx_inputs[idx] = value
+    for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)):
+        if type(value) is torch.Tensor:
+            value.to(model.device)
+        if "use_cache" in key:
+            onnx_inputs[idx] = with_past
+    out = model(sample_inputs[0], attention_mask=sample_inputs[1], use_cache=with_past) if with_past else out
+
+    return input_keys, onnx_inputs, out.past_key_values
+
+
+def move_to_appropriate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
+    """
+    According to the model size, we will upload it to
+    CPU if has no GPU or enough GPU memory,
+    Single GPU if has only one GPU in local or model size is enough to fit one GPU
+    Multiple GPU if there is more than one gpu in local and model is too large
+    """
+    total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024
+
+    print(f"Model_Size = {get_model_parameter_size(model)/1024} GB")
+    print(f"total_mem_per_cpu = {total_mem_per_cpu/1024} GB")
+    if get_model_parameter_size(model) > total_mem_per_cpu * 0.45:
+        device_collection = [torch.device(i) for i in range(torch.cuda.device_count())]
+        if len(device_collection) > 1:
+            print(
+                f"{len(device_collection)} GPUs are used to export onnx, \
+                  Please set CUDA_VISIBLE_DEVICES to use specific GPU group"
+            )
+            model = auto_pipeline_parallel(model, device_collection, sample_inputs_tp)
+        else:
+            print("!!!! convert model to float and export onnx using CPU")
+            model = model.cpu().float()
+    else:
+        print("Export model on a single GPU")
+        model = model.cuda().half()
+    return model
+
+
+def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple:
+    """move inputs to device"""
+    sample_inputs_ = []
+    for sample_int in sample_inputs:
+        if isinstance(sample_int, torch.Tensor):
+            sample_inputs_.append(sample_int.to(device))
+        else:
+            sample_inputs_.append(sample_int)
+    return tuple(sample_inputs_)
+
+
+def fetch_onnx_inputs_outputs_name(
+    model: nn.Module,
+    onnx_inputs: list,
+    torch_input_names: tuple,
+    past_key_values: tuple,
+    with_past: bool,
+    input_with_past: bool,
+):
+    """fetch onnx inputs and outputs name"""
+    num_of_past_key = 0
+    kv_cache_axis = {0: "batch_size"}
+    # try get num_of_past_key and shape of past_key_value
+    if past_key_values is not None:
+        num_of_past_key = len(past_key_values)
+        seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1)
+        assert seq_index.numel() == 1
+        kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"}
+
+    if not num_of_past_key:
+        num_of_past_key = model.config.num_hidden_layers
+
+    # filter out constant inputs
+    onnx_inp_names = tuple(
+        [torch_input_names[i] for i in range(len(torch_input_names)) if isinstance(onnx_inputs[i], torch.Tensor)]
+    )
+    assert (
+        "input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names
+    ), "input_ids and attention_mask must be existed in inputs"
+    onnx_out_names = ("logits",)
+    onnx_dynamic_axes = {
+        "input_ids": {0: "batch_size", 1: "seq_len"},
+        "attention_mask": {0: "batch_size", 1: "seq_len"},
+    }
+    # add dyanmic dimensions for the unkonw inputs
+    for idx, name in enumerate(onnx_inp_names):
+        if name not in onnx_dynamic_axes:
+            unknown_dims = {i: f"{idx}__unknown_dims__{i}" for i in range(onnx_inputs[idx].dim())}
+            onnx_dynamic_axes[name] = unknown_dims
+    if input_with_past:
+        for i in range(num_of_past_key):
+            onnx_inp_names += (f"past_key_values.{i}.key",)
+            onnx_inp_names += (f"past_key_values.{i}.value",)
+
+            onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis
+            onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis
+
+    if with_past or input_with_past:
+        for i in range(num_of_past_key):
+            onnx_out_names += (f"present.{i}.key",)
+            onnx_out_names += (f"present.{i}.value",)
+
+    for idx, name in enumerate(torch_input_names):
+        if input_with_past:
+            if name == "past_key_values":
+                onnx_inputs[idx] = past_key_values
+            elif name == "attention_mask":
+                attn_mask = onnx_inputs[idx]
+                onnx_inputs[idx] = torch.cat(
+                    (attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device, dtype=attn_mask.dtype)),
+                    dim=1,
+                )
+            elif name == "input_ids":
+                input_ids = onnx_inputs[idx]
+                onnx_inputs[idx] = input_ids[:, -1:]
+
+    return onnx_inp_names, onnx_out_names, onnx_dynamic_axes
+
+
+def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int):
+    """do export with torch.onnx.export"""
+    onnx_model_name = onnx_path.name
+    onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple
+    # two step to export onnx
+    # 1. export onnx with lots of pieces of weights
+    # 2. save all weights to external data
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        tmp_onnx = os.path.join(tmpdirname, "tmp.onnx")
+
+        torch.onnx.export(
+            model=model,
+            args=tuple(onnx_inputs),
+            f=tmp_onnx,
+            verbose=False,
+            opset_version=opset,
+            input_names=onnx_inp_names,
+            output_names=onnx_out_names,
+            dynamic_axes=onnx_dynamic_axes,
+        )
+
+        onnx_path.unlink(missing_ok=True)
+        (onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True)
+
+        onnx_model = onnx.load(str(tmp_onnx))
+        onnx.save_model(
+            onnx_model,
+            str(onnx_path),
+            save_as_external_data=(len(os.listdir(tmpdirname)) > 1),
+            all_tensors_to_one_file=True,
+            location=f"{onnx_model_name}_ext.data",
+            size_threshold=1024,
+            convert_attribute=False,
+        )
+
+
+@torch.no_grad()
+def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int):
+    """
+    do export
+    model: torch model
+    onnx_path: where the onnx model saved to
+    sample_inputs_tp: inputs for torch model
+    """
+    model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir)
+
+    model = move_to_appropriate_device(model, sample_inputs_tp)
+
+    sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)
+
+    # input_keys would be usesful if the model has some special inputs
+    input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past)
+
+    onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False)
+
+    onnx_model_name = "model.onnx"
+    onnx_path: Path = Path(onnx_path_str).absolute()
+    if onnx_path.suffix != ".onnx":
+        onnx_path = onnx_path / onnx_model_name
+
+    do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset)
+    if not with_past:
+        return
+
+    onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True)
+
+    onnx_model_name = "model_with_past.onnx"
+    onnx_path = onnx_path.parent / onnx_model_name
+
+    do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset)
+
+
+def parse_arguments():
+    """arguments parsing."""
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-m",
+        "--model",
+        required=True,
+        type=str,
+        default=["meta-llama/Llama-2-70b-hf"],
+        help="Pre-trained models in huggingface model hub",
+    )
+    parser.add_argument(
+        "-s",
+        "--saved_path",
+        required=False,
+        type=str,
+        default="./onnx_models/",
+        help="where the onnx model will be saved",
+    )
+    parser.add_argument(
+        "--cache_dir",
+        required=False,
+        type=str,
+        default=None,
+        help=("cache directly of huggingface, by setting this to avoid useless downloading if you have one"),
+    )
+    parser.add_argument(
+        "--with_past",
+        action="store_true",
+        default=False,
+        help=("The tool will export onnx without past-key-value by default"),
+    )
+    parser.add_argument(
+        "--opset",
+        required=False,
+        type=int,
+        default=17,
+        help=(
+            "the opset to save onnx model, \
+            try to increase it if this opset doens't have new features you want"
+        ),
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+
+    export_onnx(args.model, args.cache_dir, args.saved_path, args.with_past, args.opset)
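The hunk above corresponds to onnxruntime/transformers/large_model_exporter.py in the file list. For orientation, a minimal usage sketch; the model name and output directory are illustrative, and running it via `python -m` assumes the installed package layout shown above:

    python -m onnxruntime.transformers.large_model_exporter \
        --model meta-llama/Llama-2-7b-hf --saved_path ./onnx_models/ --with_past

With `--with_past`, `export_onnx` performs two exports: `model.onnx` without KV-cache inputs, then `model_with_past.onnx` with `past_key_values.*` inputs and `present.*` outputs. When the export spills tensors to external data, the weights land next to each model in `model.onnx_ext.data` and `model_with_past.onnx_ext.data` respectively.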
--- /dev/null
+++ b/onnxruntime/transformers/machine_info.py
@@ -0,0 +1,221 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+# It is used to dump machine information for Notebooks
+
+import argparse
+import json
+import logging
+import platform
+from os import environ
+from typing import Dict, List
+
+import cpuinfo
+import psutil
+from py3nvml.py3nvml import (
+    NVMLError,
+    nvmlDeviceGetCount,
+    nvmlDeviceGetHandleByIndex,
+    nvmlDeviceGetMemoryInfo,
+    nvmlDeviceGetName,
+    nvmlInit,
+    nvmlShutdown,
+    nvmlSystemGetDriverVersion,
+)
+
+
+class MachineInfo:
+    """Class encapsulating Machine Info logic."""
+
+    def __init__(self, silent=False, logger=None):
+        self.silent = silent
+
+        if logger is None:
+            logging.basicConfig(
+                format="%(asctime)s - %(name)s - %(levelname)s: %(message)s",
+                level=logging.INFO,
+            )
+            self.logger = logging.getLogger(__name__)
+        else:
+            self.logger = logger
+
+        self.machine_info = None
+        try:
+            self.machine_info = self.get_machine_info()
+        except Exception:
+            self.logger.exception("Exception in getting machine info.")
+            self.machine_info = None
+
+    def get_machine_info(self):
+        """Get machine info in metric format"""
+        gpu_info = self.get_gpu_info_by_nvml()
+        cpu_info = cpuinfo.get_cpu_info()
+
+        machine_info = {
+            "gpu": gpu_info,
+            "cpu": self.get_cpu_info(),
+            "memory": self.get_memory_info(),
+            "os": platform.platform(),
+            "python": self._try_get(cpu_info, ["python_version"]),
+            "packages": self.get_related_packages(),
+            "onnxruntime": self.get_onnxruntime_info(),
+            "pytorch": self.get_pytorch_info(),
+            "tensorflow": self.get_tensorflow_info(),
+        }
+        return machine_info
+
+    def get_memory_info(self) -> Dict:
+        """Get memory info"""
+        mem = psutil.virtual_memory()
+        return {"total": mem.total, "available": mem.available}
+
+    def _try_get(self, cpu_info: Dict, names: List) -> str:
+        for name in names:
+            if name in cpu_info:
+                value = cpu_info[name]
+                if isinstance(value, (list, tuple)):
+                    return ",".join([str(i) for i in value])
+                return value
+        return ""
+
+    def get_cpu_info(self) -> Dict:
+        """Get CPU info"""
+        cpu_info = cpuinfo.get_cpu_info()
+
+        return {
+            "brand": self._try_get(cpu_info, ["brand", "brand_raw"]),
+            "cores": psutil.cpu_count(logical=False),
+            "logical_cores": psutil.cpu_count(logical=True),
+            "hz": self._try_get(cpu_info, ["hz_actual"]),
+            "l2_cache": self._try_get(cpu_info, ["l2_cache_size"]),
+            "flags": self._try_get(cpu_info, ["flags"]),
+            "processor": platform.uname().processor,
+        }
+
+    def get_gpu_info_by_nvml(self) -> Dict:
+        """Get GPU info using nvml"""
+        gpu_info_list = []
+        driver_version = None
+        try:
+            nvmlInit()
+            driver_version = nvmlSystemGetDriverVersion()
+            deviceCount = nvmlDeviceGetCount()  # noqa: N806
+            for i in range(deviceCount):
+                handle = nvmlDeviceGetHandleByIndex(i)
+                info = nvmlDeviceGetMemoryInfo(handle)
+                gpu_info = {}
+                gpu_info["memory_total"] = info.total
+                gpu_info["memory_available"] = info.free
+                gpu_info["name"] = nvmlDeviceGetName(handle)
+                gpu_info_list.append(gpu_info)
+            nvmlShutdown()
+        except NVMLError as error:
+            if not self.silent:
+                self.logger.error("Error fetching GPU information using nvml: %s", error)
+            return None
+
+        result = {"driver_version": driver_version, "devices": gpu_info_list}
+
+        if "CUDA_VISIBLE_DEVICES" in environ:
+            result["cuda_visible"] = environ["CUDA_VISIBLE_DEVICES"]
+        return result
+
+    def get_related_packages(self) -> List[str]:
+        import pkg_resources
+
+        installed_packages = pkg_resources.working_set
+        related_packages = [
+            "onnxruntime-gpu",
+            "onnxruntime",
+            "onnx",
+            "transformers",
+            "protobuf",
+            "sympy",
+            "torch",
+            "tensorflow",
+            "flatbuffers",
+            "numpy",
+            "onnxconverter-common",
+        ]
+        related_packages_list = {i.key: i.version for i in installed_packages if i.key in related_packages}
+        return related_packages_list
+
+    def get_onnxruntime_info(self) -> Dict:
+        try:
+            import onnxruntime
+
+            return {
+                "version": onnxruntime.__version__,
+                "support_gpu": "CUDAExecutionProvider" in onnxruntime.get_available_providers(),
+            }
+        except ImportError as error:
+            if not self.silent:
+                self.logger.exception(error)
+            return None
+        except Exception as exception:
+            if not self.silent:
+                self.logger.exception(exception, False)
+            return None
+
+    def get_pytorch_info(self) -> Dict:
+        try:
+            import torch
+
+            return {
+                "version": torch.__version__,
+                "support_gpu": torch.cuda.is_available(),
+                "cuda": torch.version.cuda,
+            }
+        except ImportError as error:
+            if not self.silent:
+                self.logger.exception(error)
+            return None
+        except Exception as exception:
+            if not self.silent:
+                self.logger.exception(exception, False)
+            return None
+
+    def get_tensorflow_info(self) -> Dict:
+        try:
+            import tensorflow as tf
+
+            return {
+                "version": tf.version.VERSION,
+                "git_version": tf.version.GIT_VERSION,
+                "support_gpu": tf.test.is_built_with_cuda(),
+            }
+        except ImportError as error:
+            if not self.silent:
+                self.logger.exception(error)
+            return None
+        except ModuleNotFoundError as error:
+            if not self.silent:
+                self.logger.exception(error)
+            return None
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--silent",
+        required=False,
+        action="store_true",
+        help="Do not print error message",
+    )
+    parser.set_defaults(silent=False)
+
+    args = parser.parse_args()
+    return args
+
+
+def get_machine_info(silent=True) -> str:
+    machine = MachineInfo(silent)
+    return json.dumps(machine.machine_info, indent=2)
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    print(get_machine_info(args.silent))
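The hunk above corresponds to onnxruntime/transformers/machine_info.py in the file list. A minimal sketch of programmatic use, assuming its top-level dependencies (py-cpuinfo, psutil, py3nvml) are installed since the module imports them unconditionally:

    from onnxruntime.transformers.machine_info import get_machine_info

    # Returns a JSON string covering gpu, cpu, memory, os, python,
    # related package versions, and onnxruntime/pytorch/tensorflow info.
    print(get_machine_info(silent=True))

Run as a script (`python -m onnxruntime.transformers.machine_info --silent`), it prints the same JSON document; `--silent` suppresses the error logging that otherwise fires when NVML or an optional framework is absent.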