onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,3094 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
# -*- coding: UTF-8 -*-
|
|
5
|
+
import argparse
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import onnx
|
|
10
|
+
import sympy
|
|
11
|
+
from onnx import helper, numpy_helper, shape_inference
|
|
12
|
+
from packaging import version
|
|
13
|
+
|
|
14
|
+
# Minimum supported onnx version. An explicit check (rather than `assert`) keeps
# it active when Python runs with -O, where assert statements are removed.
if version.parse(onnx.__version__) < version.parse("1.8.0"):
    raise ImportError(f"symbolic_shape_infer requires onnx>=1.8.0, found onnx {onnx.__version__}")

logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_attribute(node, attr_name, default_value=None):
    """Return the decoded value of the attribute ``attr_name`` on ``node``.

    The first attribute in ``node.attribute`` whose ``name`` matches is decoded
    with ``helper.get_attribute_value``. If no attribute has that name,
    ``default_value`` is returned instead.
    """
    for attr in node.attribute:
        if attr.name == attr_name:
            return helper.get_attribute_value(attr)
    return default_value
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def get_dim_from_proto(dim):
    """Return the value held by the populated ``value`` oneof field of ``dim``.

    ``dim.WhichOneof("value")`` names the field that is set, and that field's
    value is returned. When no field is set (``WhichOneof`` returns a non-string,
    i.e. ``None``), ``None`` is returned.
    """
    # Query the oneof once and reuse the result instead of calling it twice.
    which = dim.WhichOneof("value")
    return getattr(dim, which) if isinstance(which, str) else None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def is_sequence(type_proto):
    """Tell whether ``type_proto`` is a sequence type (``True``) or a tensor type (``False``).

    Only the ``"tensor_type"`` and ``"sequence_type"`` cases of the ``value``
    oneof are accepted; any other case trips an assertion.
    """
    kind = type_proto.WhichOneof("value")
    if kind == "sequence_type":
        return True
    assert kind == "tensor_type"
    return False
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_shape_from_type_proto(type_proto):
    """Return the list of dimensions of a tensor ``type_proto``.

    Each entry comes from ``get_dim_from_proto`` (the dimension's value, or
    ``None`` when unset). Returns ``None`` when the tensor type has no ``shape``
    field at all. That is distinct from ``[]``, which is a scalar: a shape with
    no dims. Sequence types are rejected by an assertion.
    """
    assert not is_sequence(type_proto)
    tensor_type = type_proto.tensor_type
    if not tensor_type.HasField("shape"):
        return None  # note no shape is different from shape without dim (scalar)
    return list(map(get_dim_from_proto, tensor_type.shape.dim))
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_elem_type_from_type_proto(type_proto):
    """Return the element type of a tensor ``type_proto``, or of a sequence's tensor element.

    For a sequence type, the element type of the contained tensor type is
    returned. Otherwise the tensor type's own element type is returned.
    """
    tensor_type = (
        type_proto.sequence_type.elem_type.tensor_type if is_sequence(type_proto) else type_proto.tensor_type
    )
    return tensor_type.elem_type
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_shape_from_value_info(vi):
    """Return the shape recorded in value info ``vi``, or ``None`` when it is unknown.

    - ``vi.type`` has no populated ``value`` oneof: ``None``.
    - Sequence type: the shape of its element if that element is a tensor type,
      otherwise ``None``.
    - Tensor type: the tensor's shape, via ``get_shape_from_type_proto``.

    Any other populated type trips the assertion inside ``is_sequence``.
    """
    vi_type = vi.type
    if vi_type.WhichOneof("value") is None:
        return None
    if not is_sequence(vi_type):
        return get_shape_from_type_proto(vi_type)
    elem_type = vi_type.sequence_type.elem_type
    if elem_type.WhichOneof("value") != "tensor_type":
        return None
    return get_shape_from_type_proto(elem_type)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def make_named_value_info(name):
    """Build an ``onnx.ValueInfoProto`` with only its ``name`` field set."""
    value_info = onnx.ValueInfoProto()
    value_info.name = name
    return value_info
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def get_shape_from_sympy_shape(sympy_shape):
    """Convert a sequence of symbolic dimensions into a plain shape list.

    ``None`` entries stay ``None``. Entries for which ``is_literal`` holds become
    ``int``, and every other entry becomes its ``str`` form.
    """

    def _to_plain(dim):
        if dim is None:
            return None
        return int(dim) if is_literal(dim) else str(dim)

    return [_to_plain(dim) for dim in sympy_shape]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def is_literal(dim):
    """Tell whether ``dim`` is a concrete number rather than a symbolic expression.

    Objects whose exact type is ``int``, ``np.int64``, ``np.int32`` or
    ``sympy.Integer`` qualify. Any other object qualifies when it has an
    ``is_number`` attribute, and the value of that attribute is returned.
    """
    if type(dim) in (int, np.int64, np.int32, sympy.Integer):
        return True
    return hasattr(dim, "is_number") and dim.is_number
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def handle_negative_axis(axis, rank):
    """Map ``axis`` into the range ``[0, rank)``.

    Negative axes count from the end, so ``-1`` becomes ``rank - 1``. ``axis``
    must lie in ``[-rank, rank)``; this is enforced with an assertion.
    """
    assert -rank <= axis < rank
    if axis < 0:
        return rank + axis
    return axis
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_opset(mp, domain=None):
    """Return the opset version that model ``mp`` imports for ``domain``.

    Args:
        mp: model whose ``opset_import`` entries expose ``domain`` and ``version``.
        domain: a single domain name, or a list/tuple/set/frozenset of
            equivalent names. When omitted or empty, the default-domain aliases
            ``""``, ``"onnx"`` and ``"ai.onnx"`` are used.

    Returns:
        The ``version`` of the first ``opset_import`` entry whose ``domain`` is
        among the requested names, or ``None`` if there is no such entry.
    """
    # Empty values (None, "", []) fall back to the default aliases. "" is itself
    # one of those aliases, so passing "" still matches the default domain.
    domains = domain or ["", "onnx", "ai.onnx"]
    # Accept any common collection of names, not only an exact `list`. Previously
    # a tuple or set was wrapped as a single element and could never match.
    if not isinstance(domains, (list, tuple, set, frozenset)):
        domains = [domains]
    for opset in mp.opset_import:
        if opset.domain in domains:
            return opset.version
    return None
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def as_scalar(x):
    """Unwrap ``x`` into a single scalar value.

    - ``list`` (exact type): must have length 1 (asserted); its element is returned.
    - ``np.ndarray`` (exact type): ``x.item()`` is returned.
    - anything else, including subclasses of those types: returned unchanged.
    """
    kind = type(x)
    if kind is list:
        assert len(x) == 1
        return x[0]
    if kind is np.ndarray:
        return x.item()
    return x
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def as_list(x, keep_none):
    """Coerce ``x`` into a list.

    - ``list`` (exact type): returned as-is (the same object, not a copy).
    - ``np.ndarray`` (exact type): converted with ``list(x)``.
    - ``None`` when ``keep_none`` is truthy: ``None`` is preserved.
    - anything else (including ``None`` when ``keep_none`` is falsy): ``[x]``.
    """
    kind = type(x)
    if kind is list:
        return x
    if kind is np.ndarray:
        return list(x)
    if x is None and keep_none:
        return None
    return [x]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def sympy_reduce_product(x):
    """Return the product of the items of a list, starting from ``sympy.Integer(1)``.

    An empty list yields ``sympy.Integer(1)``. A non-list ``x`` (checked by exact
    type) is returned unchanged.
    """
    if type(x) is not list:
        return x
    product = sympy.Integer(1)
    for factor in x:
        # Rebind instead of `*=` so mutable factors (e.g. arrays) are never updated in place.
        product = product * factor
    return product
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class SymbolicShapeInference:
|
|
126
|
+
def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
    """Build the handler dispatch tables and reset the symbolic-dimension bookkeeping.

    Args:
        int_max: stored as ``self.int_max_``; its use is outside this view
            (presumably a bound treated as "unbounded" -- TODO confirm).
        auto_merge: stored as ``self.auto_merge_``. When truthy,
            ``_add_suggested_merge(..., apply=True)`` applies merges immediately.
        guess_output_rank: stored as ``self.guess_output_rank_``; used elsewhere.
        verbose: stored as ``self.verbose_``. Values > 0 enable the
            "Potential unsafe merge" warning in ``_add_suggested_merge``.
        prefix: stored as ``self.prefix_``; its use is outside this view
            (presumably prepended to generated names -- TODO confirm).
    """
    # Handler per operator type name. Keys are ONNX / contrib op type names;
    # presumably looked up by node.op_type during inference -- verify at call site.
    self.dispatcher_ = {
        "Add": self._infer_symbolic_compute_ops,
        "AllReduce": self._pass_on_shape_and_type,
        "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor,
        "AveragePool": self._infer_Pool,
        "BatchNormalization": self._infer_BatchNormalization,
        "Cast": self._infer_Cast,
        "CategoryMapper": self._infer_CategoryMapper,
        "Compress": self._infer_Compress,
        "Concat": self._infer_Concat,
        "ConcatFromSequence": self._infer_ConcatFromSequence,
        "Constant": self._infer_Constant,
        "ConstantOfShape": self._infer_ConstantOfShape,
        "Conv": self._infer_Conv,
        "CumSum": self._pass_on_shape_and_type,
        "Div": self._infer_symbolic_compute_ops,
        "Einsum": self._infer_Einsum,
        "Expand": self._infer_Expand,
        "Equal": self._infer_symbolic_compute_ops,
        "Floor": self._infer_symbolic_compute_ops,
        "Gather": self._infer_Gather,
        "GatherElements": self._infer_GatherElements,
        "GatherND": self._infer_GatherND,
        "Identity": self._pass_on_shape_and_type,
        "If": self._infer_If,
        "Loop": self._infer_Loop,
        "MatMul": self._infer_MatMul,
        "MatMulInteger16": self._infer_MatMulInteger,
        "MaxPool": self._infer_Pool,
        "Max": self._infer_symbolic_compute_ops,
        "MemcpyFromHost": self._pass_on_shape_and_type,
        "MemcpyToHost": self._pass_on_shape_and_type,
        "Min": self._infer_symbolic_compute_ops,
        "MoE": self._pass_on_shape_and_type,
        "Mul": self._infer_symbolic_compute_ops,
        "NonMaxSuppression": self._infer_NonMaxSuppression,
        "NonZero": self._infer_NonZero,
        "OneHot": self._infer_OneHot,
        "Pad": self._infer_Pad,
        "Range": self._infer_Range,
        "Reciprocal": self._pass_on_shape_and_type,
        "ReduceSum": self._infer_ReduceSum,
        "ReduceMean": self._infer_ReduceMean,
        "ReduceProd": self._infer_ReduceProd,
        "Reshape": self._infer_Reshape,
        "Resize": self._infer_Resize,
        "Round": self._pass_on_shape_and_type,
        "Scan": self._infer_Scan,
        "ScatterElements": self._infer_ScatterElements,
        "SequenceAt": self._infer_SequenceAt,
        "SequenceInsert": self._infer_SequenceInsert,
        "Shape": self._infer_Shape,
        "Size": self._infer_Size,
        "Slice": self._infer_Slice,
        "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss,
        "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss,
        "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss,
        "Split": self._infer_Split,
        "SplitToSequence": self._infer_SplitToSequence,
        "Squeeze": self._infer_Squeeze,
        "Sub": self._infer_symbolic_compute_ops,
        "Tile": self._infer_Tile,
        "TopK": self._infer_TopK,
        "Transpose": self._infer_Transpose,
        "Unsqueeze": self._infer_Unsqueeze,
        "Where": self._infer_symbolic_compute_ops,
        "ZipMap": self._infer_ZipMap,
        "Neg": self._infer_symbolic_compute_ops,
        # contrib ops:
        "Attention": self._infer_Attention,
        "BiasAdd": self._infer_BiasAdd,
        "BiasGelu": self._infer_BiasGelu,
        "BiasSplitGelu": self._infer_BiasSplitGelu,
        "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
        "DequantizeLinear": self._infer_DequantizeLinear,
        "DynamicTimeWarping": self._infer_DynamicTimeWarping,
        "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
        "FastGelu": self._infer_FastGelu,
        "GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
        "GatherBlockQuantized": self._infer_Gather,
        "Gelu": self._infer_Gelu,
        "GemmFastGelu": self._infer_GemmFastGelu,
        "GemmFloat8": self._infer_GemmFloat8,
        "GroupNorm": self._infer_GroupNorm,
        "GroupNormalization": self._infer_GroupNorm,
        "GroupQueryAttention": self._infer_GroupQueryAttention,
        "LayerNormalization": self._infer_LayerNormalization,
        "LongformerAttention": self._infer_LongformerAttention,
        "MatMulNBits": self._infer_MatMulNBits,
        "MultiHeadAttention": self._infer_MultiHeadAttention,
        "NhwcConv": self._infer_NhwcConv,
        "PackedAttention": self._infer_PackedAttention,
        "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
        "PagedAttention": self._infer_PagedAttention,
        "PythonOp": self._infer_PythonOp,
        "QLinearAdd": self._infer_QLinearBinary,
        "QLinearMul": self._infer_QLinearBinary,
        "QuantizeLinear": self._infer_QuantizeLinear,
        "QuickGelu": self._infer_FastGelu,
        "RelativePositionBias": self._infer_RelativePositionBias,
        "RemovePadding": self._infer_RemovePadding,
        "RestorePadding": self._infer_RestorePadding,
        "RotaryEmbedding": self._infer_RotaryEmbedding,
        "SimplifiedLayerNormalization": self._infer_LayerNormalization,
        "SkipGroupNorm": self._infer_SkipGroupNorm,
        "SkipLayerNormalization": self._infer_SkipLayerNormalization,
        "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
        "SparseAttention": self._infer_SparseAttention,
        "UnfoldTensor": self._infer_UnfoldTensor,
    }
    # Handlers keyed by ATen operator name (e.g. "embedding", "upsample_nearest2d");
    # presumably consulted for ATen-wrapped nodes -- verify at call site.
    self.aten_op_dispatcher_ = {
        "embedding": self._infer_Gather,
        "bitwise_or": self._infer_aten_bitwise_or,
        "diagonal": self._infer_aten_diagonal,
        "max_pool2d_with_indices": self._infer_aten_pool2d,
        "max": self._infer_aten_minmax,
        "min": self._infer_aten_minmax,
        "multinomial": self._infer_aten_multinomial,
        "unfold": self._infer_aten_unfold,
        "argmax": self._infer_aten_argmax,
        "avg_pool2d": self._infer_aten_pool2d,
        "_adaptive_avg_pool2d": self._infer_aten_pool2d,
        "numpy_T": self._infer_Transpose,
        "native_group_norm": self._infer_aten_group_norm,
        "upsample_nearest1d": self._infer_aten_upsample,
        "upsample_nearest2d": self._infer_aten_upsample,
        "upsample_nearest3d": self._infer_aten_upsample,
        "upsample_bicubic2d": self._infer_aten_upsample,
    }
    self.run_ = True  # use is outside this view
    self.suggested_merge_ = {}  # symbol name -> merge target (see _add_suggested_merge)
    self.symbolic_dims_ = {}  # symbol name -> sympy value; checked by _add_suggested_merge
    self.input_symbols_ = {}  # symbols preferred as merge targets by _add_suggested_merge
    self.auto_merge_ = auto_merge
    self.guess_output_rank_ = guess_output_rank
    self.verbose_ = verbose
    self.int_max_ = int_max
    self.subgraph_id_ = 0  # use is outside this view
    self.prefix_ = prefix
|
|
266
|
+
|
|
267
|
+
def _add_suggested_merge(self, symbols, apply=False):
    """Record that all dims in ``symbols`` should be merged into one representative.

    The representative is picked in priority order: a literal value, then a
    graph-input symbol, then a dim whose sympy form is a bare ``Symbol``, and
    finally the shortest name (logged as potentially unsafe when verbose).
    Every other dim is mapped to the representative in ``suggested_merge_``,
    and any existing entry that pointed at one of those dims is re-pointed to it.
    When ``apply`` is set and auto-merge is enabled, the merges are written back
    into the output model right away.
    """
    assert all((type(s) is str and s in self.symbolic_dims_) or is_literal(s) for s in symbols)
    pending = set(symbols)
    # Dims that were merged earlier are replaced by their current target.
    for src, dst in self.suggested_merge_.items():
        if src in pending:
            pending.remove(src)
            pending.add(dst)

    # Prefer a literal, then a graph-input symbol.
    target = next((s for s in pending if is_literal(s)), None)
    if target is None:
        target = next((s for s in pending if s in self.input_symbols_), None)
    # Next, any dim whose sympy representation is a plain Symbol.
    if target is None:
        target = next((s for s in pending if type(self.symbolic_dims_[s]) is sympy.Symbol), None)
    # Nothing preferable: fall back to the shortest expression name.
    if target is None:
        if self.verbose_ > 0:
            logger.warning("Potential unsafe merge between symbolic expressions: (%s)", ",".join(pending))
        target = min(pending, key=len)
        pending.remove(target)

    for s in pending:
        if s == target:
            continue
        target_is_literal = is_literal(target)
        if target_is_literal and is_literal(s):
            # Two literals can only be merged if they are the same value.
            assert int(target) == int(s)
        self.suggested_merge_[s] = int(target) if target_is_literal else target
        # Anything that used to merge into ``s`` now merges into the target.
        for src, dst in self.suggested_merge_.items():
            if dst == s:
                self.suggested_merge_[src] = target
    if apply and self.auto_merge_:
        self._apply_suggested_merge()
|
|
311
|
+
|
|
312
|
+
def _apply_suggested_merge(self, graph_input_only=False):
|
|
313
|
+
if not self.suggested_merge_:
|
|
314
|
+
return
|
|
315
|
+
for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)):
|
|
316
|
+
for d in i.type.tensor_type.shape.dim:
|
|
317
|
+
if d.dim_param in self.suggested_merge_:
|
|
318
|
+
v = self.suggested_merge_[d.dim_param]
|
|
319
|
+
if is_literal(v):
|
|
320
|
+
d.dim_value = int(v)
|
|
321
|
+
else:
|
|
322
|
+
d.dim_param = v
|
|
323
|
+
|
|
324
|
+
def _preprocess(self, in_mp):
    """Copy ``in_mp`` into ``out_mp_`` and build the name lookup tables.

    ``graph_inputs_`` and ``initializers_`` index the copied graph by name.
    ``known_vi_`` starts with every graph input, plus a value_info synthesized
    from each initializer's data type and dims (initializers win on name clashes).
    """
    self.out_mp_ = onnx.ModelProto()
    self.out_mp_.CopyFrom(in_mp)
    graph = self.out_mp_.graph
    self.graph_inputs_ = {vi.name: vi for vi in graph.input}
    self.initializers_ = {init.name: init for init in graph.initializer}
    self.known_vi_ = {vi.name: vi for vi in graph.input}
    for init in graph.initializer:
        self.known_vi_[init.name] = helper.make_tensor_value_info(init.name, init.data_type, list(init.dims))
|
|
336
|
+
|
|
337
|
+
def _merge_symbols(self, dims):
|
|
338
|
+
if not all(type(d) is str for d in dims):
|
|
339
|
+
if self.auto_merge_:
|
|
340
|
+
unique_dims = list(set(dims))
|
|
341
|
+
is_int = [is_literal(d) for d in unique_dims]
|
|
342
|
+
assert sum(is_int) <= 1 # if there are more than 1 unique ints, something is wrong
|
|
343
|
+
if sum(is_int) == 1:
|
|
344
|
+
int_dim = is_int.index(1)
|
|
345
|
+
if self.verbose_ > 0:
|
|
346
|
+
logger.debug(
|
|
347
|
+
f"dim {unique_dims[:int_dim] + unique_dims[int_dim + 1 :]} has been merged with value {unique_dims[int_dim]}"
|
|
348
|
+
)
|
|
349
|
+
self._check_merged_dims(unique_dims, allow_broadcast=False)
|
|
350
|
+
return unique_dims[int_dim]
|
|
351
|
+
else:
|
|
352
|
+
if self.verbose_ > 0:
|
|
353
|
+
logger.debug(f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}")
|
|
354
|
+
return dims[0]
|
|
355
|
+
else:
|
|
356
|
+
return None
|
|
357
|
+
if all(d == dims[0] for d in dims):
|
|
358
|
+
return dims[0]
|
|
359
|
+
merged = [self.suggested_merge_.get(d, d) for d in dims]
|
|
360
|
+
if all(d == merged[0] for d in merged):
|
|
361
|
+
assert merged[0] in self.symbolic_dims_
|
|
362
|
+
return merged[0]
|
|
363
|
+
else:
|
|
364
|
+
return None
|
|
365
|
+
|
|
366
|
+
# broadcast from right to left, and merge symbolic dims if needed
|
|
367
|
+
def _broadcast_shapes(self, shape1, shape2):
|
|
368
|
+
new_shape = []
|
|
369
|
+
rank1 = len(shape1)
|
|
370
|
+
rank2 = len(shape2)
|
|
371
|
+
new_rank = max(rank1, rank2)
|
|
372
|
+
for i in range(new_rank):
|
|
373
|
+
dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
|
|
374
|
+
dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
|
|
375
|
+
if dim1 == 1 or dim1 == dim2:
|
|
376
|
+
new_dim = dim2
|
|
377
|
+
elif dim2 == 1:
|
|
378
|
+
new_dim = dim1
|
|
379
|
+
else:
|
|
380
|
+
new_dim = self._merge_symbols([dim1, dim2])
|
|
381
|
+
if not new_dim:
|
|
382
|
+
# warning about unsupported broadcast when not auto merge
|
|
383
|
+
# note that auto merge has the risk of incorrectly merge symbols while one of them being 1
|
|
384
|
+
# for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
|
|
385
|
+
if self.auto_merge_:
|
|
386
|
+
self._add_suggested_merge([dim1, dim2], apply=True)
|
|
387
|
+
else:
|
|
388
|
+
logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2)) # noqa: G003
|
|
389
|
+
new_shape = [new_dim, *new_shape]
|
|
390
|
+
return new_shape
|
|
391
|
+
|
|
392
|
+
def _get_shape(self, node, idx):
|
|
393
|
+
name = node.input[idx]
|
|
394
|
+
if name in self.known_vi_:
|
|
395
|
+
vi = self.known_vi_[name]
|
|
396
|
+
return get_shape_from_value_info(vi)
|
|
397
|
+
else:
|
|
398
|
+
assert name in self.initializers_
|
|
399
|
+
return list(self.initializers_[name].dims)
|
|
400
|
+
|
|
401
|
+
def _try_get_shape(self, node, idx):
|
|
402
|
+
if idx > len(node.input) - 1:
|
|
403
|
+
return None
|
|
404
|
+
name = node.input[idx]
|
|
405
|
+
if name in self.known_vi_:
|
|
406
|
+
vi = self.known_vi_[name]
|
|
407
|
+
return get_shape_from_value_info(vi)
|
|
408
|
+
if name in self.initializers_:
|
|
409
|
+
return list(self.initializers_[name].dims)
|
|
410
|
+
return None
|
|
411
|
+
|
|
412
|
+
def _get_shape_rank(self, node, idx):
|
|
413
|
+
return len(self._get_shape(node, idx))
|
|
414
|
+
|
|
415
|
+
def _get_sympy_shape(self, node, idx):
|
|
416
|
+
sympy_shape = []
|
|
417
|
+
for d in self._get_shape(node, idx):
|
|
418
|
+
if type(d) is str:
|
|
419
|
+
sympy_shape.append(
|
|
420
|
+
self.symbolic_dims_[d]
|
|
421
|
+
if d in self.symbolic_dims_
|
|
422
|
+
else sympy.Symbol(d, integer=True, nonnegative=True)
|
|
423
|
+
)
|
|
424
|
+
else:
|
|
425
|
+
assert None is not d
|
|
426
|
+
sympy_shape.append(d)
|
|
427
|
+
return sympy_shape
|
|
428
|
+
|
|
429
|
+
def _get_value(self, node, idx):
|
|
430
|
+
name = node.input[idx]
|
|
431
|
+
assert name in self.sympy_data_ or name in self.initializers_
|
|
432
|
+
return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name])
|
|
433
|
+
|
|
434
|
+
def _try_get_value(self, node, idx):
|
|
435
|
+
if idx >= len(node.input):
|
|
436
|
+
return None
|
|
437
|
+
name = node.input[idx]
|
|
438
|
+
if name in self.sympy_data_ or name in self.initializers_:
|
|
439
|
+
return self._get_value(node, idx)
|
|
440
|
+
return None
|
|
441
|
+
|
|
442
|
+
def _update_computed_dims(self, new_sympy_shape):
    """Canonicalize computed dims of ``new_sympy_shape`` in place.

    Literal and string dims are left alone. A computed expression whose string
    form has a suggested merge is replaced by the target's sympy object (unless
    the target is a literal, in which case it is kept). Any other expression is
    registered in ``symbolic_dims_`` under its string form if not yet present.
    """
    for pos, dim in enumerate(new_sympy_shape):
        if is_literal(dim) or type(dim) == str:  # noqa: E721
            continue
        key = str(dim)
        if key not in self.suggested_merge_:
            # add new_dim if it's a computational expression
            if key not in self.symbolic_dims_:
                self.symbolic_dims_[key] = dim
            continue
        target = self.suggested_merge_[key]
        if is_literal(target):
            continue  # no need to create dim for literals
        new_sympy_shape[pos] = self.symbolic_dims_[target]
|
|
454
|
+
|
|
455
|
+
def _onnx_infer_single_node(self, node):
    """Run ONNX shape inference on ``node`` alone and register its output value_infos.

    For op types in ``skip_infer`` (handled by dedicated ``_infer_*`` methods),
    ONNX inference is not run and each non-empty output only gets a value_info
    carrying its name. For all other ops, a one-node temporary graph is built
    from the known input value_infos, inferred with
    ``shape_inference.infer_shapes``, and the resulting output value_infos are
    copied in.

    Every non-empty output gets a new entry appended to
    ``out_mp_.graph.value_info`` that is also stored in ``known_vi_``.
    """
    # skip onnx shape inference for some ops, as they are handled in _infer_*
    skip_infer = node.op_type in [
        "If",
        "Loop",
        "Scan",
        "SplitToSequence",
        "ZipMap",
        # contrib ops
        "Attention",
        "BiasAdd",
        "BiasGelu",
        "BiasSplitGelu",
        "DequantizeLinear",
        "DynamicTimeWarping",
        "EmbedLayerNormalization",
        "FastGelu",
        "GatherBlockQuantized",
        "Gelu",
        "GemmFastGelu",
        "GroupNorm",
        "GroupNormalization",
        "GroupQueryAttention",
        "LayerNormalization",
        "LongformerAttention",
        "MultiHeadAttention",
        "NhwcConv",
        "PackedAttention",
        "PagedAttention",
        "PythonOp",
        "QuantizeLinear",
        "QuickGelu",
        "RelativePositionBias",
        "RemovePadding",
        "RestorePadding",
        "RotaryEmbedding",
        "SimplifiedLayerNormalization",
        "SkipLayerNormalization",
        "SkipSimplifiedLayerNormalization",
        "SparseAttention",
        "SkipGroupNorm",
        "QLinearAdd",
        "QLinearMul",
    ]

    if not skip_infer:
        # Only pass initializers that satisfy the following conditions:
        # (1) The operator needs the value of some input for shape inference.
        #     For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output.
        # (2) opset version >= 9. In older versions, initializers are required in graph input by the onnx spec.
        # (3) The initializer is not in graph input. This means the node input is "constant" in inference.
        initializers = []
        if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]:
            initializers = [
                self.initializers_[name]
                for name in node.input
                if (name in self.initializers_ and name not in self.graph_inputs_)
            ]

        # For element-wise ops and MatMul variants whose output value_info is already
        # known, line up input dims from the right against the output rank and ask
        # _check_merged_dims to merge mismatching dims (broadcast-1 dims are allowed).
        # For MatMul variants the last two axes are excluded from this check.
        if node.op_type in [
            "Add",
            "Sub",
            "Mul",
            "Div",
            "MatMul",
            "MatMulInteger",
            "MatMulInteger16",
            "Where",
            "Sum",
        ]:
            if node.output[0] in self.known_vi_:
                vi = self.known_vi_[node.output[0]]
                out_rank = len(get_shape_from_type_proto(vi.type))
                in_shapes = [self._get_shape(node, i) for i in range(len(node.input))]
                for d in range(
                    out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)
                ):
                    # Only inputs whose rank reaches output axis d contribute a dim.
                    in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
                    if len(in_dims) > 1:
                        self._check_merged_dims(in_dims, allow_broadcast=True)

        # run single node inference with self.known_vi_ shapes
        tmp_graph = helper.make_graph(
            [node],
            "tmp",
            [self.known_vi_[i] for i in node.input if i],
            [make_named_value_info(i) for i in node.output],
            initializers,
        )

        self.tmp_mp_.graph.CopyFrom(tmp_graph)

        self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)

    for i_o in range(len(node.output)):
        o = node.output[i_o]
        if o:  # skip optional output
            vi = self.out_mp_.graph.value_info.add()
            if not skip_infer:
                vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
            else:
                # Placeholder with only the name; a dedicated _infer_* handler fills it later.
                vi.name = o
            self.known_vi_[o] = vi
|
|
557
|
+
|
|
558
|
+
def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True):
    """Run a nested SymbolicShapeInference over ``subgraph`` (a graph attribute of ``node``).

    The subgraph's own inputs are used as given. Every name already in
    ``known_vi_`` that is not shadowed by a subgraph input/initializer is passed
    as an extra (implicit) input, and matching main-graph initializers are
    copied along. The nested inference starts from copies of this instance's
    ``suggested_merge_`` and ``sympy_data_`` and runs until its ``run_`` flag
    clears.

    Its inferred outputs, value_info and nodes are then written back into
    ``subgraph``. When ``use_node_input`` is set, the first
    ``len(node.input)`` inferred inputs are written back as well. New string
    dims found in the subgraph outputs are added to ``symbolic_dims_``.

    ``inc_subgraph_id`` controls whether ``subgraph_id_`` (used in the nested
    prefix) is advanced. Returns the nested SymbolicShapeInference instance.
    """
    if self.verbose_ > 2:
        logger.debug(f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}")
    # node inputs are not passed directly to the subgraph
    # it's up to the node dispatcher to prepare subgraph input
    # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape
    # besides, inputs in subgraph could shadow implicit inputs
    subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)}
    subgraph_implicit_input = {name for name in self.known_vi_ if name not in subgraph_inputs}
    tmp_graph = helper.make_graph(
        list(subgraph.node),
        "tmp",
        list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input],
        [make_named_value_info(i.name) for i in subgraph.output],
    )
    tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input])
    tmp_graph.initializer.extend(subgraph.initializer)
    self.tmp_mp_.graph.CopyFrom(tmp_graph)

    symbolic_shape_inference = SymbolicShapeInference(
        self.int_max_,
        self.auto_merge_,
        self.guess_output_rank_,
        self.verbose_,
        prefix=self.prefix_ + "_" + str(self.subgraph_id_),
    )
    if inc_subgraph_id:
        self.subgraph_id_ += 1

    symbolic_shape_inference._preprocess(self.tmp_mp_)
    # Start from this graph's merges so dims merged here stay merged in the subgraph.
    symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
    while symbolic_shape_inference.run_:
        symbolic_shape_inference._infer_impl(self.sympy_data_.copy())
    symbolic_shape_inference._update_output_from_vi()
    if use_node_input:
        # if subgraph uses node input, it needs to update to merged dims
        subgraph.ClearField("input")
        subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)])
    subgraph.ClearField("output")
    subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output)
    subgraph.ClearField("value_info")
    subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info)
    subgraph.ClearField("node")
    subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node)
    # for new symbolic dims from subgraph output, add to main graph symbolic dims
    subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output]
    subgraph_new_symbolic_dims = {
        d for s in subgraph_shapes if s for d in s if type(d) is str and d not in self.symbolic_dims_
    }
    new_dims = {}
    for d in subgraph_new_symbolic_dims:
        assert d in symbolic_shape_inference.symbolic_dims_
        new_dims[d] = symbolic_shape_inference.symbolic_dims_[d]
    self.symbolic_dims_.update(new_dims)
    return symbolic_shape_inference
|
|
613
|
+
|
|
614
|
+
def _get_int_or_float_values(self, node, broadcast=False, allow_float_values=False):
|
|
615
|
+
def int_or_float(value, allow_float_values):
|
|
616
|
+
# If casting into int has precision loss: keep float output
|
|
617
|
+
if allow_float_values and value % 1 != 0:
|
|
618
|
+
return value
|
|
619
|
+
return int(value)
|
|
620
|
+
|
|
621
|
+
values = [self._try_get_value(node, i) for i in range(len(node.input))]
|
|
622
|
+
if all(v is not None for v in values):
|
|
623
|
+
# some shape compute is in floating point, cast to int for sympy
|
|
624
|
+
for i, v in enumerate(values):
|
|
625
|
+
if type(v) is not np.ndarray:
|
|
626
|
+
continue
|
|
627
|
+
if len(v.shape) > 1:
|
|
628
|
+
new_v = None # ignore value for rank > 1
|
|
629
|
+
elif len(v.shape) == 0:
|
|
630
|
+
new_v = int_or_float(v.item(), allow_float_values)
|
|
631
|
+
else:
|
|
632
|
+
assert len(v.shape) == 1
|
|
633
|
+
new_v = [int_or_float(vv, allow_float_values) for vv in v]
|
|
634
|
+
values[i] = new_v
|
|
635
|
+
values_len = [len(v) if isinstance(v, list) else 0 for v in values]
|
|
636
|
+
max_len = max(values_len)
|
|
637
|
+
if max_len >= 1 and broadcast:
|
|
638
|
+
# broadcast
|
|
639
|
+
for i, v in enumerate(values):
|
|
640
|
+
if v is None:
|
|
641
|
+
continue # don't broadcast if value is unknown
|
|
642
|
+
if isinstance(v, list):
|
|
643
|
+
if len(v) < max_len:
|
|
644
|
+
values[i] = v * max_len
|
|
645
|
+
else:
|
|
646
|
+
assert len(v) == max_len
|
|
647
|
+
else:
|
|
648
|
+
values[i] = [v] * max_len
|
|
649
|
+
return values
|
|
650
|
+
|
|
651
|
+
def _compute_on_sympy_data(self, node, op_func):
|
|
652
|
+
assert len(node.output) == 1
|
|
653
|
+
|
|
654
|
+
# Before mul & div operations
|
|
655
|
+
# cast inputs into interger might lose decimal part and reduce precision
|
|
656
|
+
# keep them as float, finish the operation, then cast the result into integer
|
|
657
|
+
if node.op_type in ["Mul", "Div"]:
|
|
658
|
+
values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True)
|
|
659
|
+
else:
|
|
660
|
+
values = self._get_int_or_float_values(node, broadcast=True)
|
|
661
|
+
|
|
662
|
+
if all(v is not None for v in values):
|
|
663
|
+
is_list = [isinstance(v, list) for v in values]
|
|
664
|
+
as_list = any(is_list)
|
|
665
|
+
if as_list:
|
|
666
|
+
self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values, strict=False)]
|
|
667
|
+
else:
|
|
668
|
+
self.sympy_data_[node.output[0]] = op_func(values)
|
|
669
|
+
|
|
670
|
+
def _pass_on_sympy_data(self, node):
|
|
671
|
+
assert len(node.input) == 1 or node.op_type in [
|
|
672
|
+
"Reshape",
|
|
673
|
+
"Unsqueeze",
|
|
674
|
+
"Squeeze",
|
|
675
|
+
]
|
|
676
|
+
self._compute_on_sympy_data(node, lambda x: x[0])
|
|
677
|
+
|
|
678
|
+
def _pass_on_shape_and_type(self, node):
    """Give ``node``'s first output the element type and shape of its first input."""
    out_vi = self.known_vi_[node.output[0]]
    elem_type = get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type)
    shape = self._get_shape(node, 0)
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, shape))
|
|
687
|
+
|
|
688
|
+
def _new_symbolic_dim(self, prefix, dim):
    """Return the symbolic dim named ``f"{prefix}_d{dim}"``.

    If that name already has a suggested merge, the merge target is returned
    (as ``sympy.Integer`` when it is a literal, otherwise as stored). Otherwise
    a new non-negative integer ``sympy.Symbol`` is created, registered in
    ``symbolic_dims_`` and returned.
    """
    name = f"{prefix}_d{dim}"
    if name not in self.suggested_merge_:
        symbol = sympy.Symbol(name, integer=True, nonnegative=True)
        self.symbolic_dims_[name] = symbol
        return symbol
    target = self.suggested_merge_[name]
    return sympy.Integer(int(target)) if is_literal(target) else target
|
|
697
|
+
|
|
698
|
+
def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
|
|
699
|
+
return self._new_symbolic_dim(
|
|
700
|
+
f"{node.op_type}{self.prefix_}_{list(self.out_mp_.graph.node).index(node)}_o{out_idx}_",
|
|
701
|
+
dim,
|
|
702
|
+
)
|
|
703
|
+
|
|
704
|
+
def _new_symbolic_shape(self, rank, node, out_idx=0):
|
|
705
|
+
return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)]
|
|
706
|
+
|
|
707
|
+
def _compute_conv_pool_shape(self, node, channels_last=False):
    """Compute the output sympy shape of a convolution or pooling node.

    The input shape comes from input 0. When the node has a second input (the
    weight W), its shape gives the kernel's spatial dims and the output channel
    count ``W[0]``. Otherwise the ``kernel_shape`` attribute is used.
    ``channels_last`` means the spatial axes sit just before a trailing
    channel axis; otherwise they are the last axes.

    If all input spatial dims are literals and the output value_info already has
    a shape (e.g. from ONNX inference), its spatial dims are reused. Otherwise
    each spatial output dim is computed from ``dilations``, ``strides``,
    ``pads`` / ``auto_pad`` and ``ceil_mode``.

    Returns the output shape as a list of ints / sympy expressions.
    """
    sympy_shape = self._get_sympy_shape(node, 0)
    if len(node.input) > 1:
        weight_shape = self._get_sympy_shape(node, 1)
        rank = len(weight_shape) - 2  # number of spatial axes
        kernel_shape = weight_shape[-rank - 1 : -1] if channels_last else weight_shape[-rank:]
        # Output channels: the last axis for channels_last layouts (any rank), else axis 1.
        sympy_shape[-1 if channels_last else 1] = weight_shape[0]
    else:
        kernel_shape = get_attribute(node, "kernel_shape")
        rank = len(kernel_shape)

    assert len(sympy_shape) == rank + 2

    # Spatial axes of both input and output shapes.
    spatial_slice = slice(-rank - 1, -1) if channels_last else slice(-rank, None)
    spatial_shape = sympy_shape[spatial_slice]

    # only need to symbolic shape inference if input has symbolic dims in spatial axes
    if all(is_literal(d) for d in spatial_shape):
        shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
        # ``shape`` can be None (no shape inferred yet); then fall through and compute it.
        if shape:
            assert len(sympy_shape) == len(shape)
            sympy_shape[spatial_slice] = [sympy.Integer(d) for d in shape[spatial_slice]]
            return sympy_shape

    dilations = get_attribute(node, "dilations", [1] * rank)
    strides = get_attribute(node, "strides", [1] * rank)
    effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations, strict=False)]
    pads = get_attribute(node, "pads")
    if pads is None:
        auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8")
        if auto_pad == "VALID":
            total_pads = []
        elif auto_pad == "NOTSET":
            total_pads = [0] * rank
        else:
            # SAME_*: pad so the output size is ceil(input / stride) on each spatial axis.
            try:
                residual = [sympy.Mod(d, s) for d, s in zip(spatial_shape, strides, strict=False)]
                total_pads = [
                    max(0, (k - s) if r == 0 else (k - r))
                    for k, s, r in zip(effective_kernel_shape, strides, residual, strict=False)
                ]
            except TypeError:  # sympy may throw TypeError: cannot determine truth value of Relational
                # assuming no residual if sympy throws error
                total_pads = [max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides, strict=False)]
    else:
        assert len(pads) == 2 * rank
        total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:], strict=False)]

    ceil_mode = get_attribute(node, "ceil_mode", 0)
    axis_offset = -1 if channels_last else 0
    for i in range(rank):
        axis = -rank + i + axis_offset
        effective_input_size = sympy_shape[axis]
        if len(total_pads) > 0:
            effective_input_size = effective_input_size + total_pads[i]
        if ceil_mode:
            strided_kernel_positions = sympy.ceiling(
                (effective_input_size - effective_kernel_shape[i]) / strides[i]
            )
        else:
            strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i]
        sympy_shape[axis] = strided_kernel_positions + 1
    return sympy_shape
|
|
774
|
+
|
|
775
|
+
def _check_merged_dims(self, dims, allow_broadcast=True):
    """Suggest (and apply) a merge of ``dims`` when they are not all equal.

    With ``allow_broadcast``, literal dims <= 1 are dropped first, since they
    broadcast against anything.
    """
    candidates = dims
    if allow_broadcast:
        candidates = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
    if candidates and not all(d == candidates[0] for d in candidates):
        self._add_suggested_merge(candidates, apply=True)
|
|
780
|
+
|
|
781
|
+
def _compute_matmul_shape(self, node, output_dtype=None):
    """Set the output value_info of a MatMul-like node.

    A 1-D operand contributes no output axis for its contracted dim; batch dims
    of two N-D operands are broadcast with ``_broadcast_shapes``. The contracted
    dims of both operands are merged via ``_check_merged_dims``.
    ``output_dtype`` defaults to the element type of input 0.
    """
    lhs_shape = self._get_shape(node, 0)
    rhs_shape = self._get_shape(node, 1)
    lhs_rank, rhs_rank = len(lhs_shape), len(rhs_shape)
    assert lhs_rank > 0 and rhs_rank > 0
    # Position of the contracted axis in each operand (0 for a 1-D operand).
    lhs_reduce_dim = 0 if lhs_rank == 1 else -1
    rhs_reduce_dim = 0 if rhs_rank == 1 else -2
    if lhs_rank == 1 and rhs_rank == 1:
        new_shape = []
    elif lhs_rank == 1:
        new_shape = [*rhs_shape[:-2], rhs_shape[-1]]
    elif rhs_rank == 1:
        new_shape = lhs_shape[:-1]
    else:
        batch_shape = self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2])
        new_shape = [*batch_shape, lhs_shape[-2], rhs_shape[-1]]
    # merge reduce dim
    self._check_merged_dims([lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], allow_broadcast=False)
    if output_dtype is None:
        # infer output_dtype from input type when not specified
        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
|
|
811
|
+
|
|
812
|
+
def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
    """
    Make ``dst_type`` compatible with ``src_type`` where their dimensions disagree.

    Either type may be a tensor type or a sequence-of-tensor type. The element
    types must match, otherwise ``ValueError`` is raised. If ``dst_type`` has a
    shape, every mismatched dim over the common prefix of both shapes is
    replaced: by a new symbolic dim for tensors, or by an empty (unknown) dim
    for sequences. If ``dst_type`` has no shape, the source tensor type is
    copied into it.
    """
    dst_is_seq = is_sequence(dst_type)
    dst_tensor_type = dst_type.sequence_type.elem_type.tensor_type if dst_is_seq else dst_type.tensor_type
    src_tensor_type = (
        src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type
    )
    if dst_tensor_type.elem_type != src_tensor_type.elem_type:
        node_id = node.name or node.op_type
        raise ValueError(
            f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
            f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
            f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
        )
    if not dst_tensor_type.HasField("shape"):
        dst_tensor_type.CopyFrom(src_tensor_type)
        return
    dim_pairs = zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim, strict=False)
    for di, (dst_dim, src_dim) in enumerate(dim_pairs):
        if dst_dim == src_dim:
            continue
        replacement = onnx.TensorShapeProto.Dimension()
        if not dst_is_seq:
            # Tensors get a fresh symbolic dim; sequences keep the dim empty.
            replacement.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di))
        dst_tensor_type.shape.dim[di].CopyFrom(replacement)
|
|
840
|
+
|
|
841
|
+
def _infer_ArrayFeatureExtractor(self, node):  # noqa: N802
    """Output type of input 0; shape is input 0's shape without its last axis, followed by the indices shape."""
    data_shape = self._get_shape(node, 0)
    indices_shape = self._get_shape(node, 1)
    out_vi = self.known_vi_[node.output[0]]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, data_shape[:-1] + indices_shape))
|
|
852
|
+
|
|
853
|
+
def _infer_symbolic_compute_ops(self, node):
    """Propagate sympy data through simple arithmetic / comparison ops.

    Div and Mul convert float results to ``int``. For Max, an operand that is a
    literal below ``-int_max_`` is ignored in favor of the other operand; for
    Min, likewise for a literal above ``int_max_``.
    """

    def floor_div(args):
        # integer div in sympy
        quotient = args[0] // args[1]
        return int(quotient) if isinstance(quotient, float) else quotient

    def multiply(args):
        product = args[0] * args[1]
        return int(product) if isinstance(product, float) else product

    def maximum(args):
        if is_literal(args[0]) and int(args[0]) < -self.int_max_:
            return args[1]
        if is_literal(args[1]) and int(args[1]) < -self.int_max_:
            return args[0]
        return sympy.Max(args[0], args[1])

    def minimum(args):
        if is_literal(args[0]) and int(args[0]) > self.int_max_:
            return args[1]
        if is_literal(args[1]) and int(args[1]) > self.int_max_:
            return args[0]
        return sympy.Min(args[0], args[1])

    dispatch = {
        "Add": lambda args: args[0] + args[1],
        "Div": floor_div,
        "Equal": lambda args: args[0] == args[1],
        "Floor": lambda args: sympy.floor(args[0]),
        "Max": maximum,
        "Min": minimum,
        "Mul": multiply,
        "Sub": lambda args: args[0] - args[1],
        "Where": lambda args: args[1] if args[0] else args[2],
        "Neg": lambda args: -args[0],
    }
    assert node.op_type in dispatch
    self._compute_on_sympy_data(node, dispatch[node.op_type])
|
|
878
|
+
|
|
879
|
+
def _infer_Cast(self, node): # noqa: N802
|
|
880
|
+
self._pass_on_sympy_data(node)
|
|
881
|
+
|
|
882
|
+
def _infer_CategoryMapper(self, node):  # noqa: N802
    """STRING input gives INT64 output, any other input gives STRING output; the shape is unchanged."""
    input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    output_type = onnx.TensorProto.INT64 if input_type == onnx.TensorProto.STRING else onnx.TensorProto.STRING
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_type, self._get_shape(node, 0)))
|
|
890
|
+
|
|
891
|
+
def _infer_Compress(self, node):  # noqa: N802
    """Compress output: a new symbolic length replaces the ``axis`` dim.

    Without ``axis`` the output is 1-D with just that length.
    """
    input_shape = self._get_shape(node, 0)
    # create a new symbolic dimension for Compress output
    compress_len = str(self._new_symbolic_dim_from_output(node))
    axis = get_attribute(node, "axis")
    if axis is not None:
        output_shape = input_shape
        output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len
    else:
        # when axis is not specified, input is flattened before compress so output is 1D
        output_shape = [compress_len]
    out_vi = self.known_vi_[node.output[0]]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, output_shape))
|
|
910
|
+
|
|
911
|
+
def _infer_Concat(self, node):  # noqa: N802
    """Infer Concat's output shape and, when possible, its value as sympy data.

    If some input is tracked and every input value is known (scalars or 1-D
    lists), the values are concatenated into the output's sympy data. For such
    values the concat axis must be 0 (or its alias -1).

    The output shape adds up the concat-axis dims of all inputs with a non-empty
    shape. On every other axis the input dims are merged with ``_merge_symbols``;
    the result is ``None`` when they cannot be merged.
    """
    if any(i in self.sympy_data_ or i in self.initializers_ for i in node.input):
        values = self._get_int_or_float_values(node)
        if all(v is not None for v in values):
            # Tracked values are at most 1-D, so axis -1 is equivalent to axis 0.
            assert get_attribute(node, "axis") in (0, -1)
            concatenated = []
            for value in values:
                if isinstance(value, list):
                    concatenated.extend(value)
                else:
                    concatenated.append(value)
            self.sympy_data_[node.output[0]] = concatenated

    sympy_shape = self._get_sympy_shape(node, 0)
    axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape))
    for i_idx in range(1, len(node.input)):
        input_shape = self._get_sympy_shape(node, i_idx)
        if input_shape:
            sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
    self._update_computed_dims(sympy_shape)
    # merge symbolic dims for non-concat axes
    for d in range(len(sympy_shape)):
        if d == axis:
            continue
        # Re-read shapes per axis: a merge applied on a previous axis may have
        # rewritten the known value_infos.
        input_shapes = [self._get_shape(node, i_idx) for i_idx in range(len(node.input))]
        dims = [shape[d] for shape in input_shapes if shape]
        if all(dim == dims[0] for dim in dims):
            continue
        merged = self._merge_symbols(dims)
        if type(merged) is str:
            sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
        else:
            sympy_shape[d] = merged
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            get_shape_from_sympy_shape(sympy_shape),
        )
    )
|
|
951
|
+
|
|
952
|
+
def _infer_ConcatFromSequence(self, node):  # noqa: N802
    """Output shape is the sequence element shape with a new symbolic dim on ``axis``.

    With ``new_axis`` set the dim is inserted as an extra axis; otherwise it
    replaces the existing ``axis`` dim. The element type comes from the sequence.
    """
    seq_shape = self._get_shape(node, 0)
    new_axis = 1 if get_attribute(node, "new_axis") else 0
    axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis)
    concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis))
    if new_axis:
        new_shape = [*seq_shape[:axis], concat_dim, *seq_shape[axis:]]
    else:
        new_shape = seq_shape
        new_shape[axis] = concat_dim
    out_vi = self.known_vi_[node.output[0]]
    elem_type = self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, new_shape))
|
|
970
|
+
|
|
971
|
+
def _infer_Constant(self, node):  # noqa: N802
    """Record the numpy value of the Constant's ``value`` attribute in ``sympy_data_``."""
    value_tensor = get_attribute(node, "value")
    out_name = node.output[0]
    self.sympy_data_[out_name] = numpy_helper.to_array(value_tensor)
|
|
974
|
+
|
|
975
|
+
def _infer_ConstantOfShape(self, node):  # noqa: N802
    """Infer the ConstantOfShape output; its shape is the value of input 0.

    If that value is known it becomes the output shape. For INT64 outputs whose dims
    are all literals, the filled tensor is also recorded in ``sympy_data_``. If the
    value is unknown, a new symbolic shape is created whose rank is the length of the
    1-D shape input.
    """
    sympy_shape = self._get_int_or_float_values(node)[0]
    vi = self.known_vi_[node.output[0]]
    if sympy_shape is None:
        # create new dynamic shape
        # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length
        sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node)
    else:
        if type(sympy_shape) is not list:
            sympy_shape = [sympy_shape]
        self._update_computed_dims(sympy_shape)
        # update sympy data if output type is int, and shape is known
        is_int64_output = vi.type.tensor_type.elem_type == onnx.TensorProto.INT64
        if is_int64_output and all(is_literal(dim) for dim in sympy_shape):
            filled = np.ones([int(dim) for dim in sympy_shape], dtype=np.int64)
            self.sympy_data_[node.output[0]] = filled * numpy_helper.to_array(get_attribute(node, "value", 0))

    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            vi.type.tensor_type.elem_type,
            get_shape_from_sympy_shape(sympy_shape),
        )
    )
|
|
999
|
+
|
|
1000
|
+
def _infer_Conv(self, node):  # noqa: N802
    """Infer the Conv output shape using the shared conv/pool shape computation.

    The output keeps the element type already recorded on its value info.
    """
    conv_shape = self._compute_conv_pool_shape(node)
    self._update_computed_dims(conv_shape)
    out_vi = self.known_vi_[node.output[0]]
    out_elem_type = out_vi.type.tensor_type.elem_type
    out_vi.CopyFrom(
        helper.make_tensor_value_info(node.output[0], out_elem_type, get_shape_from_sympy_shape(conv_shape))
    )
|
|
1011
|
+
|
|
1012
|
+
def _infer_NhwcConv(self, node):  # noqa: N802
    """Infer the channels-last Conv output shape; the element type follows input 0."""
    nhwc_shape = self._compute_conv_pool_shape(node, channels_last=True)
    self._update_computed_dims(nhwc_shape)
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, in_elem_type, get_shape_from_sympy_shape(nhwc_shape)))
|
|
1023
|
+
|
|
1024
|
+
def _infer_DequantizeLinear(self, node):  # noqa: N802
    """Infer DequantizeLinear output: the shape of input 0 with the element type of the scale.

    The scale is input 1, which is required.
    """
    scale_elem_type = self.known_vi_[node.input[1]].type.tensor_type.elem_type
    quantized_shape = self._get_shape(node, 0)
    out_name = node.output[0]
    self.known_vi_[out_name].CopyFrom(helper.make_tensor_value_info(out_name, scale_elem_type, quantized_shape))
|
|
1033
|
+
|
|
1034
|
+
def _infer_QuantizeLinear(self, node):  # noqa: N802
    """Infer QuantizeLinear output: same shape as input 0, quantized element type.

    The element type is chosen in this order:
    1. The element type of the zero-point input (index 2, optional), when present.
    2. The ``output_dtype`` attribute (added in opset 21), when set to a non-zero
       (non-UNDEFINED) value.
    3. uint8, as a fallback.
    """
    output_dtype = onnx.TensorProto.UINT8
    if len(node.input) > 2 and node.input[2]:
        output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type
    else:
        # Without a zero point, opset 21+ models may carry the type in `output_dtype`;
        # 0 (UNDEFINED) means the attribute was not supplied.
        attr_dtype = get_attribute(node, "output_dtype", 0)
        if attr_dtype:
            output_dtype = attr_dtype

    # Get the output shape from the first input.
    output_shape = self._get_shape(node, 0)

    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
|
|
1046
|
+
|
|
1047
|
+
def _infer_QLinearBinary(self, node):  # noqa: N802
    """Infer QLinearAdd / QLinearMul output: broadcast of the two quantized operands.

    The operands are inputs 0 and 3. The element type is taken from input 0.
    """
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    lhs_shape = self._get_shape(node, 0)
    rhs_shape = self._get_shape(node, 3)
    broadcast_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, broadcast_shape))
|
|
1060
|
+
|
|
1061
|
+
def _infer_Einsum(self, node):  # noqa: N802
    """Infer the output shape of an Einsum node from its equation.

    Each comma-separated term on the left-hand side is bound to the shape of the input
    at the same position:
    - Letters before an ellipsis ("...") bind to the leading dims.
    - Letters after it bind to the trailing dims.
    - The ellipsis covers the remaining dims. Ellipsis dims of all operands that use
      one are broadcast together.

    In explicit mode ("->" present), the output follows the right-hand side, with the
    ellipsis dims placed where "..." appears. In implicit mode, the output is the
    ellipsis dims followed by every letter that occurs exactly once on the left-hand
    side, in order of first appearance. The output element type follows input 0.
    """
    # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
    from collections import Counter  # noqa: PLC0415

    equation = get_attribute(node, "equation")
    equation = equation.replace(b" ", b"")
    mid_index = equation.find(b"->")
    left_equation = equation[:mid_index] if mid_index != -1 else equation

    letter_to_dim = {}
    # Broadcast of the dims covered by "..." across operands; None until an ellipsis is seen.
    ellipsis_dims = None

    # Iterating bytes yields ints: 44 == ord(","), 46 == ord(".").
    for operand_idx, term in enumerate(left_equation.split(b",")):
        shape = self._get_shape(node, operand_idx)
        rank = len(shape)
        ellipsis_index = term.find(b"...")
        if ellipsis_index == -1:
            bindings = list(zip(term, shape, strict=False))
        else:
            head = term[:ellipsis_index]
            tail = term[ellipsis_index + 3 :]
            # Letters after "..." align with the trailing dims of this operand.
            tail_start = rank - len(tail)
            term_ellipsis_dims = list(shape[len(head) : tail_start])
            if ellipsis_dims is None:
                ellipsis_dims = term_ellipsis_dims
            else:
                ellipsis_dims = self._broadcast_shapes(ellipsis_dims, term_ellipsis_dims)
            bindings = list(zip(head, shape[: len(head)], strict=False)) + list(
                zip(tail, shape[tail_start:], strict=False)
            )
        for letter, dim in bindings:
            # Keep the first dim seen for a letter, but let a later concrete dim replace it
            # (symbolic dims from _get_shape may be strings/None, not sympy.Symbol).
            if letter not in letter_to_dim or is_literal(dim):
                letter_to_dim[letter] = dim

    if ellipsis_dims is None:
        ellipsis_dims = []

    if mid_index != -1:
        right_equation = equation[mid_index + 2 :]
        right_ellipsis_index = right_equation.find(b"...")
        if right_ellipsis_index == -1:
            new_sympy_shape = [letter_to_dim[c] for c in right_equation]
        else:
            new_sympy_shape = [
                *(letter_to_dim[c] for c in right_equation[:right_ellipsis_index]),
                *ellipsis_dims,
                *(letter_to_dim[c] for c in right_equation[right_ellipsis_index + 3 :]),
            ]
    else:
        # Implicit output: letters that appear exactly once, in first-appearance order.
        occurrences = Counter(c for c in left_equation if c not in (44, 46))
        new_sympy_shape = [
            *ellipsis_dims,
            *(letter_to_dim[c] for c, count in occurrences.items() if count == 1),
        ]

    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape))
|
|
1122
|
+
|
|
1123
|
+
def _infer_Expand(self, node):  # noqa: N802
    """Infer the Expand output by broadcasting input 0's shape against the target shape.

    The target shape is the value of input 1. If that value is not statically known,
    nothing is inferred here. The output element type follows input 0.
    """
    target_shape = as_list(self._try_get_value(node, 1), keep_none=True)
    if target_shape is None:
        return
    # dims of the target shape value may be newly computed expressions
    self._update_computed_dims(target_shape)
    out_shape = self._broadcast_shapes(self._get_shape(node, 0), get_shape_from_sympy_shape(target_shape))
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            out_shape,
        )
    )
|
|
1138
|
+
|
|
1139
|
+
def _infer_Gather(self, node):  # noqa: N802
    """Infer the output of Gather / GatherBlockQuantized.

    The output shape is ``data_shape[:axis] + indices_shape + data_shape[axis + 1:]``.
    The element type comes from the data input for Gather, and from input 2 (scales)
    for GatherBlockQuantized. For 1-D data with known sympy values and a known index
    value, the gathered values are also propagated into ``sympy_data_``.

    Raises:
        ValueError: if ``node.op_type`` is neither Gather nor GatherBlockQuantized.
    """
    data_shape = self._get_shape(node, 0)
    axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape))
    indices_shape = self._get_shape(node, 1)
    vi = self.known_vi_[node.output[0]]
    if node.op_type == "Gather":
        elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    elif node.op_type == "GatherBlockQuantized":
        # scales
        elem_type = self.known_vi_[node.input[2]].type.tensor_type.elem_type
    else:
        raise ValueError(f"Unsupported Gather op_type: {node.op_type}")
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            elem_type,
            data_shape[:axis] + indices_shape + data_shape[axis + 1 :],
        )
    )
    # For 1-D input, do some sympy compute. The normalized axis is checked so that
    # axis=-1 (equivalent to 0 for 1-D data) is handled as well.
    if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and axis == 0:
        idx = self._try_get_value(node, 1)
        if idx is not None:
            data = self.sympy_data_[node.input[0]]
            if type(data) is list:
                if type(idx) is np.ndarray and len(idx.shape) == 1:
                    self.sympy_data_[node.output[0]] = [data[int(i)] for i in idx]
                else:
                    self.sympy_data_[node.output[0]] = data[int(idx)]
            else:
                assert idx == 0 or idx == -1
                self.sympy_data_[node.output[0]] = data
|
|
1171
|
+
|
|
1172
|
+
def _infer_GatherElements(self, node):  # noqa: N802
    """Infer GatherElements output: the shape of the indices (input 1), data element type."""
    indices_shape = self._get_shape(node, 1)
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    data_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, data_elem_type, indices_shape))
|
|
1182
|
+
|
|
1183
|
+
def _infer_GatherND(self, node):  # noqa: N802
    """Infer the output shape of GatherND.

    With ``b = batch_dims`` and ``k = indices_shape[-1]``, the output shape is
    ``indices_shape[:-1] + data_shape[b + k:]``. ``k`` must be a literal no larger
    than ``rank(data) - b``. The output element type follows the data input.
    """
    data_shape = self._get_shape(node, 0)
    data_rank = len(data_shape)
    indices_shape = self._get_shape(node, 1)
    # Leading batch dims (opset 12+) are shared by data and indices; the default 0 gives
    # the original opset-11 behavior.
    batch_dims = get_attribute(node, "batch_dims", 0)
    last_index_dimension = indices_shape[-1]
    assert is_literal(last_index_dimension) and last_index_dimension <= data_rank - batch_dims
    new_shape = indices_shape[:-1] + data_shape[batch_dims + last_index_dimension :]
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            new_shape,
        )
    )
|
|
1199
|
+
|
|
1200
|
+
def _infer_If(self, node):  # noqa: N802
    """Infer the outputs of an If node from its then/else subgraphs.

    Each branch subgraph is inferred with ``_onnx_infer_subgraph``. Output types come
    from the then-branch, fused with the else-branch via ``_fuse_tensor_type``. When the
    condition (input 0) is a known constant, the non-taken branch is overwritten with
    the taken one before inference. The taken branch's sympy data for each output is
    then passed on to the node outputs.
    """
    # special case for constant condition, in case there are mismatching shape from the non-executed branch
    subgraphs = [
        get_attribute(node, "then_branch"),
        get_attribute(node, "else_branch"),
    ]
    cond = self._try_get_value(node, 0)
    if cond is not None:
        # Make both entries identical to the branch that will actually execute.
        if as_scalar(cond) > 0:
            subgraphs[1].CopyFrom(subgraphs[0])
        else:
            subgraphs[0].CopyFrom(subgraphs[1])

    for i_sub, subgraph in enumerate(subgraphs):
        subgraph_infer = self._onnx_infer_subgraph(node, subgraph, use_node_input=False)
        for i_out in range(len(node.output)):
            vi = self.known_vi_[node.output[i_out]]
            if i_sub == 0:
                # first branch seeds the output value info; rename it to the node output
                vi.CopyFrom(subgraph.output[i_out])
                vi.name = node.output[i_out]
            else:
                # second branch is merged into the type recorded from the first
                self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type)

            # pass on sympy data from subgraph, if cond is constant
            if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else 1):
                if subgraph.output[i_out].name in subgraph_infer.sympy_data_:
                    self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name]
|
|
1227
|
+
|
|
1228
|
+
def _infer_Loop(self, node):  # noqa: N802
    """Infer the outputs of a Loop node by running inference on its ``body`` subgraph.

    The steps are:
    1. Body inputs are seeded from the node inputs, keeping their own names, and the
       body is inferred once.
    2. Loop-carried tensors whose body input/output dims differ get a shared new
       symbolic dim. Loop-carried sequences whose element shape contains None get the
       output element type copied back to the input. In either case the body is
       inferred a second time.
    3. Node outputs are copied from body outputs (skipping the body's leading condition
       output). Outputs at index >= num_loop_carried (presumably scan outputs — confirm)
       get an extra leading symbolic dim for the iteration count.
    """
    subgraph = get_attribute(node, "body")
    assert len(subgraph.input) == len(node.input)
    num_loop_carried = len(node.input) - 2  # minus the length and initial loop condition
    # when sequence_type is used as loop carried input
    # needs to run subgraph infer twice if the tensor shape in sequence contains None
    for i, si in enumerate(subgraph.input):
        si_name = si.name
        si.CopyFrom(self.known_vi_[node.input[i]])
        # restore the body's own input name after copying the outer value info
        si.name = si_name

    self._onnx_infer_subgraph(node, subgraph)

    # check subgraph input/output for shape changes in loop carried variables
    # for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a)
    # for sequence_type, propagate from output to input
    need_second_infer = False
    for i_out in range(1, num_loop_carried + 1):
        so = subgraph.output[i_out]
        so_shape = get_shape_from_value_info(so)
        if is_sequence(so.type):
            if so_shape and None in so_shape:
                # copy shape from output to input
                # note that loop input is [loop_len, cond, input_0, input_1, ...]
                # while loop output is [cond, output_0, output_1, ...]
                subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type)
                need_second_infer = True
        else:
            si = subgraph.input[i_out + 1]
            si_shape = get_shape_from_value_info(si)
            for di, dims in enumerate(zip(si_shape, so_shape, strict=False)):
                if dims[0] != dims[1]:
                    # the dim changes across iterations: give input and output the same new symbol
                    new_dim = onnx.TensorShapeProto.Dimension()
                    new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, i_out, di))
                    si.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
                    so.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
                    need_second_infer = True

    if need_second_infer:
        if self.verbose_ > 2:
            logger.debug(
                f"Rerun Loop: {node.name}({node.output[0]}...), because of sequence in loop carried variables"
            )
        self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False)

    # create a new symbolic dimension for iteration dependent dimension
    loop_iter_dim = str(self._new_symbolic_dim_from_output(node))
    for i in range(len(node.output)):
        vi = self.known_vi_[node.output[i]]
        vi.CopyFrom(subgraph.output[i + 1])  # first subgraph output is condition, not in node output
        if i >= num_loop_carried:
            assert not is_sequence(vi.type)  # TODO: handle loop accumulation in sequence_type
            # prepend the iteration-count dim to the per-iteration body output shape
            subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim
            vi.type.tensor_type.shape.ClearField("dim")
            vi_dim = vi.type.tensor_type.shape.dim
            vi_dim.add().dim_param = loop_iter_dim
            vi_dim.extend(list(subgraph_vi_dim))
        vi.name = node.output[i]
|
|
1286
|
+
|
|
1287
|
+
def _infer_MatMul(self, node):  # noqa: N802
    """Infer MatMul output via the shared matmul shape helper, using its default output type."""
    self._compute_matmul_shape(node)
|
|
1289
|
+
|
|
1290
|
+
def _infer_MatMulInteger(self, node):  # noqa: N802
    """Infer MatMulInteger output via the shared matmul shape helper with an INT32 output type."""
    int32_type = onnx.TensorProto.INT32
    self._compute_matmul_shape(node, int32_type)
|
|
1292
|
+
|
|
1293
|
+
def _infer_MatMulNBits(self, node):  # noqa: N802
    """Infer MatMulNBits output from input 0's shape and the ``K`` / ``N`` attributes.

    The weight is treated as a ``[K, N]`` matrix. A 1-D input yields ``[N]``. Otherwise
    the last dim of input 0 is replaced by ``N``, after checking that it merges with
    ``K`` without broadcasting. The output element type follows input 0.
    """
    a_shape = self._get_shape(node, 0)
    k_dim, n_dim = get_attribute(node, "K"), get_attribute(node, "N")
    assert len(a_shape) > 0
    if len(a_shape) == 1:
        out_shape = [n_dim]
    else:
        out_shape = [*a_shape[:-1], n_dim]
        # merge reduce dim
        self._check_merged_dims([a_shape[-1], k_dim], allow_broadcast=False)
    # infer output_dtype from input type when not specified
    a_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], a_elem_type, out_shape))
|
|
1311
|
+
|
|
1312
|
+
def _infer_NonMaxSuppression(self, node):  # noqa: N802
    """Infer NonMaxSuppression output: INT64 of shape ``[num_selected, 3]``.

    ``num_selected`` is a new symbolic dim.
    """
    num_selected = str(self._new_symbolic_dim_from_output(node))
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, onnx.TensorProto.INT64, [num_selected, 3]))
|
|
1316
|
+
|
|
1317
|
+
def _infer_NonZero(self, node):  # noqa: N802
    """Infer NonZero output shape ``[rank(input), count]``, with ``count`` a new symbolic dim.

    The output keeps the element type already recorded on its value info.
    """
    rank = self._get_shape_rank(node, 0)
    # the number of non-zero elements depends on the data
    nonzero_count = str(self._new_symbolic_dim_from_output(node, 0, 1))
    out_vi = self.known_vi_[node.output[0]]
    out_elem_type = out_vi.type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], out_elem_type, [rank, nonzero_count]))
|
|
1323
|
+
|
|
1324
|
+
def _infer_OneHot(self, node):  # noqa: N802
    """Infer OneHot output: the indices shape with a depth dim inserted at ``axis``.

    ``depth`` (input 1) may be a scalar or a one-element tensor. When its value is
    known it is used as the inserted dim; otherwise a new symbolic dim is used. The
    output element type follows ``values`` (input 2).
    """
    sympy_shape = self._get_sympy_shape(node, 0)
    depth = self._try_get_value(node, 1)
    if depth is not None:
        # A constant depth typically arrives as a one-element array; reduce it to a scalar
        # so it is recognized as a literal. Non-integer depth is cast to int by the op.
        depth = as_scalar(depth)
        if isinstance(depth, (float, np.floating)):
            depth = int(depth)
    axis = get_attribute(node, "axis", -1)
    axis = handle_negative_axis(axis, len(sympy_shape) + 1)
    new_shape = get_shape_from_sympy_shape(
        [
            *sympy_shape[:axis],
            self._new_symbolic_dim_from_output(node) if not is_literal(depth) else depth,
            *sympy_shape[axis:],
        ]
    )
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[2]].type.tensor_type.elem_type,
            new_shape,
        )
    )
|
|
1344
|
+
|
|
1345
|
+
def _infer_Pad(self, node):  # noqa: N802
    """Infer Pad output: each dim grows by its begin and end pad amounts.

    Pads come from the ``pads`` attribute for opset <= 10, and from input 1 otherwise.
    From opset 18, an optional ``axes`` input (input 3) restricts the pads to the listed
    axes; those pads are expanded to a full-rank pads list here. When the pads (or the
    axes they refer to) are not statically known, a fresh symbolic shape of the same
    rank is produced. The output element type follows input 0.
    """
    if get_opset(self.out_mp_) <= 10:
        pads = get_attribute(node, "pads")
    else:
        pads = self._try_get_value(node, 1)

    sympy_shape = self._get_sympy_shape(node, 0)
    rank = len(sympy_shape)

    if pads is not None and len(node.input) > 3 and node.input[3]:
        # Opset 18+: pads are [begin_0 .. begin_{n-1}, end_0 .. end_{n-1}] for axes[0..n-1].
        axes = self._try_get_value(node, 3)
        if axes is None:
            pads = None  # the pads cannot be placed without knowing the axes
        else:
            axes = [handle_negative_axis(a, rank) for a in axes]
            num_axes = len(axes)
            assert len(pads) == 2 * num_axes
            full_pads = [0] * (2 * rank)
            for i, a in enumerate(axes):
                full_pads[a] = pads[i]
                full_pads[rank + a] = pads[num_axes + i]
            pads = full_pads

    if pads is not None:
        assert len(pads) == 2 * rank
        new_sympy_shape = [
            d + pad_up + pad_down
            for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:], strict=False)
        ]
        self._update_computed_dims(new_sympy_shape)
    else:
        # dynamic pads, create new symbolic dimensions
        new_sympy_shape = self._new_symbolic_shape(rank, node)
    output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type

    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(
        helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape))
    )
|
|
1370
|
+
|
|
1371
|
+
def _infer_Pool(self, node):  # noqa: N802
    """Set the shape from the conv/pool helper on every non-empty output of a pooling node.

    Each output keeps the element type already recorded on its value info.
    """
    pooled_shape = self._compute_conv_pool_shape(node)
    self._update_computed_dims(pooled_shape)
    # empty output names denote omitted optional outputs
    for out_name in filter(None, node.output):
        out_vi = self.known_vi_[out_name]
        out_vi.CopyFrom(
            helper.make_tensor_value_info(
                out_name,
                out_vi.type.tensor_type.elem_type,
                get_shape_from_sympy_shape(pooled_shape),
            )
        )
|
|
1385
|
+
|
|
1386
|
+
def _infer_aten_bitwise_or(self, node):
    """Infer aten bitwise_or output: broadcast of both input shapes, input 0's element type."""
    lhs_shape = self._get_shape(node, 0)
    rhs_shape = self._get_shape(node, 1)
    out_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
    lhs_vi = self.known_vi_[node.input[0]]
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], lhs_vi.type.tensor_type.elem_type, out_shape))
|
|
1393
|
+
|
|
1394
|
+
def _infer_aten_diagonal(self, node):
    """Infer aten diagonal output shape.

    ``offset``, ``dim1`` and ``dim2`` (inputs 1-3) must be statically known. Dims
    ``dim1`` and ``dim2`` are removed and a diagonal-length dim is appended:
    - ``max(0, min(size1, size2 - offset))`` when ``offset >= 0``;
    - ``max(0, min(size1 + offset, size2))`` otherwise.
    """
    sympy_shape = self._get_sympy_shape(node, 0)
    rank = len(sympy_shape)
    offset, dim1, dim2 = (self._try_get_value(node, i) for i in (1, 2, 3))

    assert offset is not None and dim1 is not None and dim2 is not None
    dim1 = handle_negative_axis(dim1, rank)
    dim2 = handle_negative_axis(dim2, rank)

    new_shape = [size for axis_idx, size in enumerate(sympy_shape) if axis_idx not in (dim1, dim2)]

    size1, size2 = sympy_shape[dim1], sympy_shape[dim2]
    if offset >= 0:
        diag_len = sympy.Max(0, sympy.Min(size1, size2 - offset))
    else:
        diag_len = sympy.Max(0, sympy.Min(size1 + offset, size2))
    new_shape.append(diag_len)

    if not node.output[0]:
        return
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            get_shape_from_sympy_shape(new_shape),
        )
    )
|
|
1427
|
+
|
|
1428
|
+
def _infer_aten_multinomial(self, node):
    """Infer aten multinomial output: INT64, input shape with the last dim set to num_samples.

    The input must be 1-D or 2-D. When ``num_samples`` (input 1) is unknown or falsy,
    a new symbolic dim is used instead.
    """
    sympy_shape = self._get_sympy_shape(node, 0)
    rank = len(sympy_shape)
    assert rank in [1, 2]
    num_samples = self._try_get_value(node, 1)
    if num_samples:
        last_dim = num_samples
    else:
        last_dim = str(self._new_symbolic_dim_from_output(node, 0, rank - 1))
    output_shape = sympy_shape[:-1] + [last_dim]
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            onnx.TensorProto.INT64,
            get_shape_from_sympy_shape(output_shape),
        )
    )
|
|
1444
|
+
|
|
1445
|
+
def _infer_aten_pool2d(self, node):
    """Infer aten 2-D pooling outputs for a rank-4 input.

    The last two dims are replaced by new symbolic dims. Output 1, when present, is
    INT64; the other outputs use input 0's element type.
    """
    sympy_shape = self._get_sympy_shape(node, 0)
    assert len(sympy_shape) == 4
    # output spatial dims are not computed here; fresh symbols stand in for them
    sympy_shape[2] = self._new_symbolic_dim_from_output(node, 0, 2)
    sympy_shape[3] = self._new_symbolic_dim_from_output(node, 0, 3)
    self._update_computed_dims(sympy_shape)
    for out_idx, out_name in enumerate(node.output):
        if not out_name:
            continue
        out_vi = self.known_vi_[out_name]
        if out_idx == 1:
            elem_type = onnx.TensorProto.INT64
        else:
            elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_vi.CopyFrom(helper.make_tensor_value_info(out_name, elem_type, get_shape_from_sympy_shape(sympy_shape)))
|
|
1456
|
+
|
|
1457
|
+
def _infer_aten_minmax(self, node):
    """Infer aten min/max outputs.

    With a single input, output 0 is a scalar of the input's element type. With three
    inputs, input 1 is the reduction dim and input 2 is keepdim (which must be known).
    Output 0 (input element type) and output 1 (INT64) then get the input shape reduced
    along dim, kept as 1 when keepdim is set. If dim is unknown, a new symbolic shape
    of the resulting rank is used.
    """
    out_vi = self.known_vi_[node.output[0]]
    if len(node.input) == 1:
        out_vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, []
            )
        )
        return

    assert len(node.input) == 3
    keepdim = self._try_get_value(node, 2)
    assert keepdim is not None  # can only handle known keepdim case.
    dim = self._try_get_value(node, 1)
    if dim is not None:
        in_shape = self._get_sympy_shape(node, 0)
        dim = handle_negative_axis(dim, len(in_shape))
        reduced_shape = in_shape[:dim] + ([1] if keepdim else []) + in_shape[dim + 1 :]
    else:
        in_rank = self._get_shape_rank(node, 0)
        reduced_shape = self._new_symbolic_shape(in_rank if keepdim else in_rank - 1, node)

    output_shape = get_shape_from_sympy_shape(reduced_shape)
    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, output_shape
        )
    )
    index_vi = self.known_vi_[node.output[1]]
    index_vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape))
|
|
1489
|
+
|
|
1490
|
+
def _infer_aten_unfold(self, node):
    """Infer aten unfold output shape.

    When ``dimension``, ``size`` and ``step`` (inputs 1-3) are all known, that dim
    becomes ``(len - size) // step + 1`` and ``size`` is appended as a new last dim.
    Otherwise a new symbolic shape of rank + 1 is produced. The output element type
    follows input 0.
    """
    sympy_shape = self._get_sympy_shape(node, 0)
    dimension = self._try_get_value(node, 1)
    size = self._try_get_value(node, 2)
    step = self._try_get_value(node, 3)
    all_known = dimension is not None and size is not None and step is not None
    if not all_known:
        sympy_shape = self._new_symbolic_shape(len(sympy_shape) + 1, node)
    else:
        assert dimension < len(sympy_shape)
        sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1
        sympy_shape.append(size)
    self._update_computed_dims(sympy_shape)
    if not node.output[0]:
        return
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            get_shape_from_sympy_shape(sympy_shape),
        )
    )
|
|
1512
|
+
|
|
1513
|
+
def _infer_aten_argmax(self, node):
    """Infer aten argmax output (INT64).

    Without a dim input, the result is a scalar (argmax over the flattened input). With
    a dim, the shape is only inferred when keepdim (input 2) is known: the dim is kept
    as 1 or removed. If dim itself is unknown, a new symbolic shape of the resulting
    rank is used.
    """
    new_shape = None
    if not node.input[1]:
        # The argmax of the flattened input is returned.
        new_shape = []
    else:
        dim = self._try_get_value(node, 1)
        keepdim = self._try_get_value(node, 2)
        if keepdim is not None:
            sympy_shape = self._get_sympy_shape(node, 0)
            if dim is None:
                in_rank = len(sympy_shape)
                sympy_shape = self._new_symbolic_shape(in_rank if keepdim else in_rank - 1, node)
            else:
                dim = handle_negative_axis(dim, len(sympy_shape))
                if keepdim:
                    sympy_shape[dim] = 1
                else:
                    del sympy_shape[dim]
            self._update_computed_dims(sympy_shape)
            new_shape = get_shape_from_sympy_shape(sympy_shape)
    if not node.output[0] or new_shape is None:
        return
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))
|
|
1537
|
+
|
|
1538
|
+
def _infer_aten_group_norm(self, node):
    """Infer aten group_norm outputs.

    Output 0 is handled by ``_propagate_shape_and_type``. Outputs 1 and 2 (when present)
    get shape ``[N, group]`` with input 0's element type, where:
    - ``N`` is input 0's leading dim;
    - ``group`` is the value of input 6.
    Either dim becomes a new symbolic dim when unknown.
    """
    self._propagate_shape_and_type(node)
    input_shape = self._get_shape(node, 0)
    batch_dim = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None
    group = self._try_get_value(node, 6)
    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    for out_idx in (1, 2):
        out_name = node.output[out_idx]
        if not out_name:
            continue
        out_vi = self.known_vi_[out_name]
        if batch_dim is not None:
            first_dim = batch_dim
        else:
            first_dim = str(self._new_symbolic_dim_from_output(node, out_idx, 0))
        if group is not None:
            second_dim = as_scalar(group)
        else:
            second_dim = str(self._new_symbolic_dim_from_output(node, out_idx, 1))
        out_vi.CopyFrom(helper.make_tensor_value_info(out_name, output_dtype, [first_dim, second_dim]))
|
|
1561
|
+
|
|
1562
|
+
def _infer_aten_upsample(self, node):
    """Infer aten upsample output: input 0's first two dims followed by the output sizes.

    The spatial sizes come from the value of input 1 when it is known; otherwise new
    symbolic dims are generated. Nothing is set when input 0's shape is unknown or
    output 0 is omitted.
    """
    input_shape = self._get_shape(node, 0)
    if input_shape is None:
        return
    new_shape = input_shape[:2]
    output_size = self._try_get_value(node, 1)
    if output_size is None:
        new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, len(input_shape))]
    else:
        # numpy int64 entries are converted to Python ints
        new_shape += [size.item() if type(size) is np.int64 else size for size in output_size]
    if not node.output[0]:
        return
    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
|
|
1577
|
+
|
|
1578
|
+
def _infer_BatchNormalization(self, node):  # noqa: N802
    """Propagate BatchNormalization shapes and types.

    Output 0 uses the default propagation. Outputs 1-4, when present and non-empty,
    take the shape and type of input 1. Indices past ``len(node.output)`` are skipped,
    so opsets with fewer outputs are handled as well.
    """
    self._propagate_shape_and_type(node)
    num_outputs = len(node.output)
    for out_idx in range(1, 5):
        if out_idx >= num_outputs or not node.output[out_idx]:
            continue
        # all of these parameters have the same shape as the 1st input
        self._propagate_shape_and_type(node, input_index=1, output_index=out_idx)
|
|
1586
|
+
|
|
1587
|
+
def _infer_Range(self, node):  # noqa: N802
    """Infer the 1-D Range output length.

    When start, limit and delta are all known, the length is
    ``max(ceil((limit - start) / delta), 0)``; otherwise it is a new symbolic dim.
    The output element type follows input 0.
    """
    vi = self.known_vi_[node.output[0]]
    input_data = self._get_int_or_float_values(node)
    if any(value is None for value in input_data):
        new_sympy_shape = [self._new_symbolic_dim_from_output(node)]
    else:
        start, limit, delta = (as_scalar(value) for value in input_data[:3])
        new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)]
    self._update_computed_dims(new_sympy_shape)
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            get_shape_from_sympy_shape(new_sympy_shape),
        )
    )
|
|
1605
|
+
|
|
1606
|
+
def _infer_ReduceSum(self, node):  # noqa: N802
    """Infer the output shape of ReduceSum for opset >= 13, where axes is input 1.

    Also used for ReduceMean at opset >= 18. Cases handled when an axes input is listed:
    - Axes value unknown: a new symbolic shape of the input rank is produced (requires
      keepdims=1).
    - Axes omitted (empty input name) or empty: all dims are reduced, unless
      ``noop_with_empty_axes`` is set, in which case the shape is unchanged.
    - Otherwise the listed dims are reduced (kept as 1 when keepdims=1).
    For lower opsets, or when the node lists no axes input, nothing is done here.
    """
    keep_dims = get_attribute(node, "keepdims", 1)
    if get_opset(self.out_mp_) >= 13 and len(node.input) > 1:
        # ReduceSum changes axes to input[1] in opset 13
        # An empty input name means the optional axes input was omitted, i.e. no axes.
        axes = self._try_get_value(node, 1) if node.input[1] else []
        vi = self.known_vi_[node.output[0]]
        if axes is None:
            assert keep_dims  # can only handle keep_dims==True when axes is unknown, by generating new ranks
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                    get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)),
                )
            )
        else:
            shape = self._get_shape(node, 0)
            if len(axes) == 0:
                # Empty axes reduce every dim by default; noop_with_empty_axes=1 makes it an identity.
                axes = [] if get_attribute(node, "noop_with_empty_axes", 0) else list(range(len(shape)))
            else:
                axes = [handle_negative_axis(a, len(shape)) for a in axes]
            output_shape = []
            for i, d in enumerate(shape):
                if i in axes:
                    if keep_dims:
                        output_shape.append(1)
                else:
                    output_shape.append(d)
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                    output_shape,
                )
            )
|
|
1638
|
+
|
|
1639
|
+
def _infer_ReduceMean(self, node):  # noqa: N802
    """Infer ReduceMean only for opset >= 18, delegating to the ReduceSum logic.

    The reduce mean spec for 18+ is the same as the reduce sum spec for 13+.
    """
    if get_opset(self.out_mp_) < 18:
        return
    self._infer_ReduceSum(node)
|
|
1643
|
+
|
|
1644
|
+
def _infer_ReduceProd(self, node):  # noqa: N802
    """Record sympy data for a ReduceProd over ``axes == [0]`` with ``keepdims == 0``.

    The product is recorded only when the input values are known. No shape is set here.
    """
    axes = get_attribute(node, "axes")
    keep_dims = get_attribute(node, "keepdims", 1)
    if not (keep_dims == 0 and axes == [0]):
        return
    values = self._get_int_or_float_values(node)[0]
    if values is None:
        return
    self.sympy_data_[node.output[0]] = sympy_reduce_product(values)
|
|
1651
|
+
|
|
1652
|
+
def _infer_RelativePositionBias(self, node):  # noqa: N802
    """Infer RelativePositionBias output shape ``[1, num_heads, seq_len, real_seq_len]``.

    ``num_heads`` is dim 1 of input 0. Both sequence lengths (inputs 1 and 2) must be
    known, otherwise nothing is inferred. The output element type follows input 0.
    """
    seq_len = self._try_get_value(node, 1)
    real_seq_len = self._try_get_value(node, 2)
    if seq_len is None or real_seq_len is None:
        return
    num_heads = self._get_sympy_shape(node, 0)[1]
    bias_shape = [1, num_heads, str(seq_len), str(real_seq_len)]
    bias_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], bias_dtype, bias_shape))
|
|
1664
|
+
|
|
1665
|
+
def _infer_Reshape(self, node):  # noqa: N802
    """Infer the output shape of Reshape and pass symbolic data through.

    If the target shape (input 1) has no known value, only its length is used:
    the output gets that many fresh symbolic dims.

    Otherwise each target entry is used as-is, except:
    - a literal ``0`` copies the input dim at the same position;
    - a single ``-1`` is solved as
      ``prod(input dims) // prod(all other target dims)``.
    """
    target = self._try_get_value(node, 1)
    out_vi = self.known_vi_[node.output[0]]

    if target is None:
        target_shape = self._get_shape(node, 1)
        assert len(target_shape) == 1
        out_rank = target_shape[0]
        assert is_literal(out_rank)
        out_sympy_shape = self._new_symbolic_shape(out_rank, node)
    else:
        in_sympy_shape = self._get_sympy_shape(node, 0)
        element_count = 1
        for dim in in_sympy_shape:
            element_count = element_count * dim

        out_sympy_shape = []
        inferred_idx = -1
        known_product = 1
        for idx, dim in enumerate(target):
            # Symbolic entries are kept as-is; a literal 0 copies the input dim.
            if type(dim) is not sympy.Symbol and dim == 0:
                out_sympy_shape.append(in_sympy_shape[idx])
                known_product = known_product * in_sympy_shape[idx]
            else:
                out_sympy_shape.append(dim)
            # Every entry except -1 (to be solved) and 0 (already counted above)
            # contributes to the known part of the element count.
            if dim == -1:
                inferred_idx = idx
            elif dim != 0:
                known_product = known_product * dim

        assert out_sympy_shape.count(-1) < 2
        if -1 in out_sympy_shape:
            out_sympy_shape[inferred_idx] = element_count // known_product

        self._update_computed_dims(out_sympy_shape)

    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            out_vi.type.tensor_type.elem_type,
            get_shape_from_sympy_shape(out_sympy_shape),
        )
    )

    self._pass_on_sympy_data(node)
|
|
1716
|
+
|
|
1717
|
+
def _infer_Resize(self, node):  # noqa: N802
    """Infer the output shape of Resize.

    Opset <= 10: output dim = floor(input dim * scale), computed only when
    ``scales`` (input 1) is known. Otherwise the output value info is left
    untouched.

    Opset >= 11, in order of preference:
    - ``sizes`` (input 3), when known, gives the output dims directly;
    - otherwise ``scales`` (input 2) gives
      floor(input dim * (roi_end - roi_start) * scale). The ROI (input 1) is
      only used when ``coordinate_transformation_mode == "tf_crop_and_resize"``;
      in every other mode the full range [0, 1] is used;
    - otherwise every output dim becomes a fresh symbol. This also applies in
      crop mode when the ROI value is unknown.

    NOTE(review): the opset-18 ``axes`` and ``keep_aspect_ratio_policy``
    attributes are not read here. sizes/scales are therefore assumed to cover
    every axis under the default "stretch" policy.
    """
    vi = self.known_vi_[node.output[0]]
    input_sympy_shape = self._get_sympy_shape(node, 0)
    if get_opset(self.out_mp_) <= 10:
        scales = self._try_get_value(node, 1)
        if scales is not None:
            new_sympy_shape = [
                sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales, strict=False)
            ]
            self._update_computed_dims(new_sympy_shape)
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                    get_shape_from_sympy_shape(new_sympy_shape),
                )
            )
        return

    roi = self._try_get_value(node, 1)
    scales = self._try_get_value(node, 2)
    sizes = self._try_get_value(node, 3)

    # ONNX stores string attributes as bytes (AttributeProto.s), so comparing the
    # raw attribute value with a str literal never matches. Normalise to str
    # first so the crop-and-resize ROI is actually honoured.
    mode = get_attribute(node, "coordinate_transformation_mode")
    if isinstance(mode, bytes):
        mode = mode.decode("utf-8")
    crop_and_resize = mode == "tf_crop_and_resize"

    if sizes is not None:
        new_sympy_shape = [sympy.simplify(sympy.floor(s)) for s in sizes]
        self._update_computed_dims(new_sympy_shape)
    elif scales is not None and not (crop_and_resize and roi is None):
        rank = len(scales)
        if crop_and_resize:
            assert len(roi) == 2 * rank
            roi_start = list(roi)[:rank]
            roi_end = list(roi)[rank:]
        else:
            roi_start = [0] * rank
            roi_end = [1] * rank
        scales = list(scales)
        new_sympy_shape = [
            sympy.simplify(sympy.floor(d * (end - start) * scale))
            for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales, strict=False)
        ]
        self._update_computed_dims(new_sympy_shape)
    else:
        # Output size cannot be computed: neither sizes nor scales are known,
        # or crop mode has an unknown ROI.
        new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)

    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            get_shape_from_sympy_shape(new_sympy_shape),
        )
    )
|
|
1766
|
+
|
|
1767
|
+
def _infer_Scan(self, node):  # noqa: N802
    """Infer Scan outputs by running ONNX inference on the ``body`` subgraph.

    The first ``len(node.input) - num_scan_inputs`` inputs (and the same number
    of outputs) are loop states; the remaining ones are scan inputs/outputs.
    - Each scan input is fed to the body with its scan axis removed.
    - Each scan output gets an extra dim re-inserted at its scan output axis.
      That dim is the length of the last scan input along its scan axis.
    - Loop-state outputs copy the body output's value info.
    """
    subgraph = get_attribute(node, "body")
    num_scan_inputs = get_attribute(node, "num_scan_inputs")
    scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs)
    num_scan_states = len(node.input) - num_scan_inputs
    # Resolve (possibly negative) scan axes against the rank of each scan input.
    scan_input_axes = [
        handle_negative_axis(ax, self._get_shape_rank(node, i + num_scan_states))
        for i, ax in enumerate(scan_input_axes)
    ]
    # We may have cases where the subgraph has optional inputs that appear in both subgraph's input and initializer,
    # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs.
    assert len(subgraph.input) >= len(node.input)
    subgraph_inputs = subgraph.input[: len(node.input)]
    for i, si in enumerate(subgraph_inputs):
        # Replace the body input's type with the outer input's, but keep the body's own name.
        subgraph_name = si.name
        si.CopyFrom(self.known_vi_[node.input[i]])
        if i >= num_scan_states:
            # Scan input: the body sees one slice per iteration, so drop the scan axis.
            scan_input_dim = si.type.tensor_type.shape.dim
            scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]])
        si.name = subgraph_name
    self._onnx_infer_subgraph(node, subgraph)
    num_scan_outputs = len(node.output) - num_scan_states
    scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs)
    # Iteration count: length of the last scan input along its scan axis.
    scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
    for i, o in enumerate(node.output):
        vi = self.known_vi_[o]
        if i >= num_scan_states:
            # Scan output: insert the iteration count at the resolved scan output axis
            # (the axis is resolved against the output rank, i.e. body rank + 1).
            shape = get_shape_from_type_proto(subgraph.output[i].type)
            new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1)
            shape = [*shape[:new_dim], scan_input_dim, *shape[new_dim:]]
            vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape))
        else:
            # Loop-state output: take the body output's value info unchanged.
            vi.CopyFrom(subgraph.output[i])
        vi.name = o
|
|
1801
|
+
|
|
1802
|
+
def _infer_ScatterElements(self, node):  # noqa: N802
    """ScatterElements output has the shape and element type of ``data`` (input 0)."""
    data_shape = self._get_shape(node, 0)
    out_vi = self.known_vi_[node.output[0]]
    data_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], data_elem_type, data_shape))
|
|
1812
|
+
|
|
1813
|
+
def _infer_SequenceAt(self, node):  # noqa: N802
    """Give unknown dims of the SequenceAt output fresh symbolic names.

    Dims that are ``None`` in input 0's shape are replaced, in order, by new
    symbolic dims. These are written into output 0's value info at the same
    positions.
    """
    seq_shape = self._get_shape(node, 0)
    out_vi = self.known_vi_[node.output[0]]
    if seq_shape is None:
        return

    unknown_positions = [pos for pos, dim in enumerate(seq_shape) if dim is None]
    for pos in unknown_positions:
        replacement = onnx.TensorShapeProto.Dimension()
        replacement.dim_param = str(self._new_symbolic_dim_from_output(node, 0, pos))
        out_vi.type.tensor_type.shape.dim[pos].CopyFrom(replacement)
|
|
1824
|
+
|
|
1825
|
+
def _infer_SequenceInsert(self, node): # noqa: N802
|
|
1826
|
+
# workaround bug in onnx's shape inference
|
|
1827
|
+
vi_seq = self.known_vi_[node.input[0]]
|
|
1828
|
+
vi_tensor = self.known_vi_[node.input[1]]
|
|
1829
|
+
vi_out_seq = self.known_vi_[node.output[0]]
|
|
1830
|
+
vi_out_seq.CopyFrom(vi_seq)
|
|
1831
|
+
vi_out_seq.name = node.output[0]
|
|
1832
|
+
self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type)
|
|
1833
|
+
|
|
1834
|
+
def _infer_Shape(self, node):  # noqa: N802
    """Record the (possibly sliced) shape of input 0 as the output's symbolic data.

    Shape-15 added optional ``start``/``end`` attributes that select a slice of
    the shape. They used to be ignored, so the recorded data held the full shape
    even when the op returns only part of it. When the attributes are absent,
    the full shape is recorded as before.
    """
    sympy_shape = self._get_sympy_shape(node, 0)
    start = get_attribute(node, "start", 0)
    end = get_attribute(node, "end", None)
    # Python slicing follows the ONNX rule exactly: negative values count from
    # the back, and out-of-range values are clamped to [0, rank].
    self.sympy_data_[node.output[0]] = sympy_shape[start:end]
|
|
1836
|
+
|
|
1837
|
+
def _infer_Size(self, node):  # noqa: N802
    """Size: scalar INT64 output whose symbolic data is the product of input 0's dims.

    The product is computed by ``sympy_reduce_product``.
    """
    out_name = node.output[0]
    element_count = sympy_reduce_product(self._get_sympy_shape(node, 0))
    self.sympy_data_[out_name] = element_count
    scalar_vi = helper.make_tensor_value_info(out_name, onnx.TensorProto.INT64, [])
    self.known_vi_[out_name].CopyFrom(scalar_vi)
|
|
1843
|
+
|
|
1844
|
+
def _infer_Slice(self, node):  # noqa: N802
    """Infer the output shape of Slice and propagate symbolic data for 1-D slices.

    starts/ends/axes come from attributes for opset <= 9 (steps are then all 1).
    From opset 10 on, starts/ends/axes/steps come from inputs 1-4.

    - If starts or ends are unknown, the sliced dims become fresh symbolic dims
      (every dim when axes is unknown too).
    - Otherwise each sliced dim is ``ceil((end - start) / step)``, after
      negative indices are normalised and ends/starts are clamped to the dim.
      Sympy expressions are used where dims are symbolic.
    """

    # SymPy fails to prove that `x_0 + ... + x_n >= 0` if one of `x_i` is a `sympy.Min(a, b)`,
    # even when the relation holds for both `a` and `b`.
    #
    # When given `expr` of form `min(a, b) + ...`, this function returns `[a + ..., b + ...]`,
    # so that we can prove inequalities for both expressions separately.
    #
    # If the number of `min(...)` subexpressions is not exactly one, this function just returns `[expr]`.
    def flatten_min(expr):
        assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}"
        min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)]
        if len(min_positions) == 1:
            min_pos = min_positions[0]

            def replace_min_with_arg(arg_idx):
                replaced = list(expr.args)
                assert isinstance(replaced[min_pos], sympy.Min), (
                    f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}"
                )
                assert len(replaced[min_pos].args) == 2, (
                    f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}"
                )
                replaced[min_pos] = replaced[min_pos].args[arg_idx]
                return sympy.Add(*replaced)

            return [
                replace_min_with_arg(0),
                replace_min_with_arg(1),
            ]
        return [expr]

    # Try increasingly indirect formulations of `x <= y` until sympy can decide one;
    # sympy raises TypeError when a relational's truth value is undetermined.
    def less_equal(x, y):
        try:
            return bool(x <= y)
        except TypeError:
            pass
        try:
            return bool(y >= x)
        except TypeError:
            pass
        try:
            return bool(-x >= -y)
        except TypeError:
            pass
        try:
            return bool(-y <= -x)
        except TypeError:
            pass
        try:
            return bool(y - x >= 0)
        except TypeError:
            # the last attempt; this may raise TypeError
            return all(bool(d >= 0) for d in flatten_min(y - x))

    def handle_negative_index(index, bound):
        """normalizes a negative index to be in [0, bound)"""
        try:
            if not less_equal(0, index):
                if is_literal(index) and index <= -self.int_max_:
                    # this case is handled separately
                    return index
                return bound + index
        except TypeError:
            logger.warning(f"Cannot determine if {index} < 0")
        return index

    if get_opset(self.out_mp_) <= 9:
        axes = get_attribute(node, "axes")
        starts = get_attribute(node, "starts")
        ends = get_attribute(node, "ends")
        if not axes:
            axes = list(range(len(starts)))
        steps = [1] * len(axes)
    else:
        starts = as_list(self._try_get_value(node, 1), keep_none=True)
        ends = as_list(self._try_get_value(node, 2), keep_none=True)
        axes = self._try_get_value(node, 3)
        steps = self._try_get_value(node, 4)
        if axes is None and not (starts is None and ends is None):
            axes = list(range(len(starts if starts is not None else ends)))
        if steps is None and not (starts is None and ends is None):
            steps = [1] * len(starts if starts is not None else ends)
        axes = as_list(axes, keep_none=True)
        steps = as_list(steps, keep_none=True)

    new_sympy_shape = self._get_sympy_shape(node, 0)
    if starts is None or ends is None:
        if axes is None:
            for i in range(len(new_sympy_shape)):
                new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
        else:
            new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
            for i in axes:
                new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
    else:
        for i, s, e, t in zip(axes, starts, ends, steps, strict=False):
            e = handle_negative_index(e, new_sympy_shape[i])  # noqa: PLW2901
            if is_literal(e):
                if e >= self.int_max_:
                    e = new_sympy_shape[i]  # noqa: PLW2901
                elif e <= -self.int_max_:
                    e = 0 if s > 0 else -1  # noqa: PLW2901
                elif is_literal(new_sympy_shape[i]):
                    if e < 0:
                        e = max(0, e + new_sympy_shape[i])  # noqa: PLW2901
                    e = min(e, new_sympy_shape[i])  # noqa: PLW2901
                else:
                    if e > 0:
                        e = (  # noqa: PLW2901
                            sympy.Min(e, new_sympy_shape[i]) if e > 1 else e
                        )  # special case for slicing first to make computation easier
            else:
                if is_literal(new_sympy_shape[i]):
                    e = sympy.Min(e, new_sympy_shape[i])  # noqa: PLW2901
                else:
                    try:
                        if not less_equal(e, new_sympy_shape[i]):
                            e = new_sympy_shape[i]  # noqa: PLW2901
                    except Exception:
                        logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal")
                        e = new_sympy_shape[i]  # noqa: PLW2901

            s = handle_negative_index(s, new_sympy_shape[i])  # noqa: PLW2901
            if is_literal(new_sympy_shape[i]) and is_literal(s):
                s = max(0, min(s, new_sympy_shape[i]))  # noqa: PLW2901

            # Element count of the slice: ceil((e - s) / t) for either sign of t.
            new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)

        self._update_computed_dims(new_sympy_shape)

    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            vi.type.tensor_type.elem_type,
            get_shape_from_sympy_shape(new_sympy_shape),
        )
    )

    # handle sympy_data if needed, for slice in shape computation
    if (
        node.input[0] in self.sympy_data_
        and axes == [0]
        and starts is not None
        and len(starts) == 1
        and ends is not None
        and len(ends) == 1
        and steps is not None
        and len(steps) == 1
    ):
        input_sympy_data = self.sympy_data_[node.input[0]]
        # np.array is a factory function, not a type, so `type(x) is np.array`
        # can never be true. Test against np.ndarray so 1-D array data is also
        # sliced.
        if type(input_sympy_data) is list or (
            isinstance(input_sympy_data, np.ndarray) and len(input_sympy_data.shape) == 1
        ):
            self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]]
|
|
1999
|
+
|
|
2000
|
+
def _infer_SoftmaxCrossEntropyLoss(self, node):  # noqa: N802
    """Infer the outputs of SoftmaxCrossEntropyLoss.

    Element type is input 0's, unless an ``output_type`` attribute overrides it.

    Output 0 (loss):
    - a scalar for the reducing modes ("mean", the default, and "sum");
    - with ``reduction == "none"``, one value per label, so it takes the shape
      of the labels (input 1). If that shape is unknown, the existing output
      shape is kept.

    Optional output 1 gets input 0's shape.
    """
    vi = self.known_vi_[node.output[0]]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type

    # If output type is explicitly specified in attribute, we use it as output tensor type.
    specified_output_type = get_attribute(node, "output_type", None)
    if specified_output_type is not None:
        elem_type = specified_output_type

    # ONNX stores string attributes as bytes; normalise before comparing.
    reduction = get_attribute(node, "reduction", "mean")
    if isinstance(reduction, bytes):
        reduction = reduction.decode("utf-8")

    vi.type.tensor_type.elem_type = elem_type
    if reduction == "none":
        # Previously output 0 was forced to a scalar even here, which is wrong for
        # an unreduced loss.
        label_shape = self._try_get_shape(node, 1)
        if label_shape is not None:
            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, label_shape))
    else:
        vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())

    if len(node.output) > 1:
        data_shape = self._get_shape(node, 0)
        vi = self.known_vi_[node.output[1]]
        vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, data_shape))
|
|
2016
|
+
|
|
2017
|
+
def _infer_Split_Common(self, node, make_value_info_func):  # noqa: N802
    """Shared shape inference for Split and SplitToSequence.

    Output ``i`` gets input 0's shape with the split axis replaced by ``split[i]``,
    plus input 0's element type. ``make_value_info_func`` builds the value info
    (tensor or sequence).

    ``split`` comes from the attribute before opset 13 and from input 1 from
    opset 13 on. When it is absent, the axis is divided across
    ``len(node.output)`` outputs:
    - a concrete axis length gives chunks of ``ceil(dim / n)``, and the last
      chunk takes the remainder (Split-18 ``num_outputs`` semantics; identical
      to ``dim / n`` for even splits);
    - a symbolic axis length is divided evenly.
    """
    input_sympy_shape = self._get_sympy_shape(node, 0)
    axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape))
    op_set = get_opset(self.out_mp_)

    # Depending on op-version 'split' are provided as attribute or via 2nd input
    if op_set < 13:
        split = get_attribute(node, "split")
        assert self._try_get_value(node, 1) is None
    else:
        split = self._try_get_value(node, 1)
        assert get_attribute(node, "split") is None

    if split is None:
        num_outputs = len(node.output)
        split_dim = input_sympy_shape[axis]
        if is_literal(split_dim):
            # Uneven splits used to floor every chunk (e.g. 10 / 3 -> [3, 3, 3]).
            # The op yields ceil-sized chunks with a smaller last one ([4, 4, 2]).
            dim_value = int(split_dim)
            chunk = -(-dim_value // num_outputs)
            last_chunk = dim_value - chunk * (num_outputs - 1)
            split = [sympy.Integer(chunk)] * (num_outputs - 1) + [sympy.Integer(last_chunk)]
        else:
            split = [split_dim / sympy.Integer(num_outputs)] * num_outputs
        self._update_computed_dims(split)
    else:
        split = [sympy.Integer(s) for s in split]

    for i_o in range(len(split)):
        vi = self.known_vi_[node.output[i_o]]
        vi.CopyFrom(
            make_value_info_func(
                node.output[i_o],
                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                get_shape_from_sympy_shape([*input_sympy_shape[:axis], split[i_o], *input_sympy_shape[axis + 1 :]]),
            )
        )
        self.known_vi_[vi.name] = vi
|
|
2047
|
+
|
|
2048
|
+
def _infer_Split(self, node):  # noqa: N802
    """Split emits plain tensor outputs; the shape logic lives in _infer_Split_Common."""
    tensor_value_info = helper.make_tensor_value_info
    self._infer_Split_Common(node, tensor_value_info)
|
|
2050
|
+
|
|
2051
|
+
def _infer_SplitToSequence(self, node):  # noqa: N802
    """SplitToSequence emits sequence value infos; the shape logic lives in _infer_Split_Common."""
    sequence_value_info = helper.make_sequence_value_info
    self._infer_Split_Common(node, sequence_value_info)
|
|
2053
|
+
|
|
2054
|
+
def _infer_Squeeze(self, node):  # noqa: N802
    """Infer the output shape of Squeeze and pass symbolic data through.

    ``axes`` is an attribute before opset 13 and input 1 from opset 13 on.
    - Without axes, every dim equal to 1 is dropped. Symbolic dims are assumed
      to be != 1.
    - With axes, the listed dims are dropped. A concrete listed dim must be 1;
      a symbolic one is assumed to be 1.
    """
    input_shape = self._get_shape(node, 0)

    # Depending on op-version 'axes' are provided as attribute or via 2nd input
    if get_opset(self.out_mp_) >= 13:
        axes = self._try_get_value(node, 1)
        assert get_attribute(node, "axes") is None
    else:
        axes = get_attribute(node, "axes")
        assert self._try_get_value(node, 1) is None

    if axes is None:
        # No axes provided (neither via attribute nor via input): the Squeeze op
        # removes every dim equal to 1. Symbolic dims are guessed to be != 1.
        output_shape = [dim for dim in input_shape if dim != 1]
        if self.verbose_ > 0:
            symbolic_dimensions = [dim for dim in input_shape if type(dim) is not int]
            if symbolic_dimensions:
                logger.debug(
                    f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
                    f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}"
                )
    else:
        squeezed = {handle_negative_axis(a, len(input_shape)) for a in axes}
        output_shape = []
        for i, dim in enumerate(input_shape):
            if i not in squeezed:
                output_shape.append(dim)
                continue
            is_symbolic = type(dim) is not int
            assert dim == 1 or is_symbolic
            if self.verbose_ > 0 and is_symbolic:
                logger.debug(
                    f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
                    f"Assuming the dimension '{dim}' at index {i} of the input to be equal to 1."
                )

    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            output_shape,
        )
    )
    self._pass_on_sympy_data(node)
|
|
2101
|
+
|
|
2102
|
+
def _infer_Tile(self, node):  # noqa: N802
    """Infer Tile output: each input dim multiplied by its repeat count.

    When ``repeats`` (input 1) is unknown, the output gets fresh symbolic dims
    of the input's rank. The output keeps its existing element type.
    """
    repeats = self._try_get_value(node, 1)
    if repeats is None:
        new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
    else:
        new_sympy_shape = [dim * repeats[idx] for idx, dim in enumerate(self._get_sympy_shape(node, 0))]
        self._update_computed_dims(new_sympy_shape)

    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            out_vi.type.tensor_type.elem_type,
            get_shape_from_sympy_shape(new_sympy_shape),
        )
    )
|
|
2121
|
+
|
|
2122
|
+
def _infer_TopK(self, node):  # noqa: N802
    """Infer TopK outputs: input 0's shape with the ``axis`` dim replaced by ``k``.

    ``k`` is an attribute up to opset 9 and comes from input 1 afterwards. When
    it is unknown, a fresh symbolic dim is used. Every output keeps its own
    element type.
    """
    rank = self._get_shape_rank(node, 0)
    axis = handle_negative_axis(get_attribute(node, "axis", -1), rank)
    new_shape = self._get_shape(node, 0)

    k = get_attribute(node, "k") if get_opset(self.out_mp_) <= 9 else self._get_int_or_float_values(node)[1]
    k = self._new_symbolic_dim_from_output(node) if k is None else as_scalar(k)

    if type(k) in (int, str):
        new_shape[axis] = k
    else:
        # k may be computed from symbolic data, so register it as a computed dim
        # before it enters the shape.
        new_sympy_shape = self._get_sympy_shape(node, 0)
        new_sympy_shape[axis] = k
        self._update_computed_dims(new_sympy_shape)
        new_shape = get_shape_from_sympy_shape(new_sympy_shape)

    for out_name in node.output:
        out_vi = self.known_vi_[out_name]
        out_vi.CopyFrom(helper.make_tensor_value_info(out_name, out_vi.type.tensor_type.elem_type, new_shape))
|
|
2150
|
+
|
|
2151
|
+
def _infer_Transpose(self, node): # noqa: N802
|
|
2152
|
+
if node.input[0] in self.sympy_data_:
|
|
2153
|
+
data_shape = self._get_shape(node, 0)
|
|
2154
|
+
perm = get_attribute(node, "perm", reversed(list(range(len(data_shape)))))
|
|
2155
|
+
input_data = self.sympy_data_[node.input[0]]
|
|
2156
|
+
self.sympy_data_[node.output[0]] = (
|
|
2157
|
+
np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist()
|
|
2158
|
+
)
|
|
2159
|
+
|
|
2160
|
+
def _infer_Unsqueeze(self, node):  # noqa: N802
    """Infer the output shape of Unsqueeze and pass symbolic data through.

    ``axes`` is an attribute before opset 13 and input 1 from opset 13 on. The
    output has rank ``len(input) + len(axes)``: a 1 at each listed axis and the
    input dims, in order, elsewhere.

    If ``axes`` has no known value, ``len(axes)`` used to raise TypeError. Now:
    - if the axes tensor's length is known, the output gets fresh symbolic dims
      of the resulting rank;
    - otherwise the output value info is left as is.
    """
    input_shape = self._get_shape(node, 0)
    op_set = get_opset(self.out_mp_)

    # Depending on op-version 'axes' are provided as attribute or via 2nd input
    if op_set < 13:
        axes = get_attribute(node, "axes")
        assert self._try_get_value(node, 1) is None
    else:
        axes = self._try_get_value(node, 1)
        assert get_attribute(node, "axes") is None

    vi = self.known_vi_[node.output[0]]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type

    if axes is None:
        # Runtime axes: positions of the new unit dims are unknown, but the rank
        # is known whenever the 1-D axes tensor has a concrete length.
        axes_shape = self._try_get_shape(node, 1)
        if axes_shape is not None and len(axes_shape) == 1 and is_literal(axes_shape[0]):
            output_rank = len(input_shape) + int(axes_shape[0])
            new_shape = get_shape_from_sympy_shape(self._new_symbolic_shape(output_rank, node))
            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, new_shape))
        return

    output_rank = len(input_shape) + len(axes)
    axes = [handle_negative_axis(a, output_rank) for a in axes]

    input_axis = 0
    output_shape = []
    for i in range(output_rank):
        if i in axes:
            output_shape.append(1)
        else:
            output_shape.append(input_shape[input_axis])
            input_axis += 1

    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, output_shape))

    self._pass_on_sympy_data(node)
|
|
2194
|
+
|
|
2195
|
+
def _infer_ZipMap(self, node):  # noqa: N802
    """ZipMap output: a sequence of maps from class label to FLOAT.

    The key type is INT64 when ``classlabels_int64s`` is set, otherwise STRING
    when ``classlabels_strings`` is set. One of the two attributes must be
    present.
    """
    if get_attribute(node, "classlabels_int64s") is not None:
        map_key_type = onnx.TensorProto.INT64
    elif get_attribute(node, "classlabels_strings") is not None:
        map_key_type = onnx.TensorProto.STRING
    else:
        map_key_type = None
    assert map_key_type is not None

    out_name = node.output[0]
    new_vi = onnx.ValueInfoProto()
    new_vi.name = out_name
    map_type = new_vi.type.sequence_type.elem_type.map_type
    map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
    map_type.key_type = map_key_type
    self.known_vi_[out_name].CopyFrom(new_vi)
|
|
2209
|
+
|
|
2210
|
+
def _infer_Attention(self, node):  # noqa: N802
    """Infer the outputs of the Attention contrib op (only when input 0 is 3-D).

    Output 0 has input 0's shape with the last dim replaced by:
    - ``qkv_hidden_sizes[2]`` when that attribute is set;
    - otherwise one third of the bias length (or of ``weights.shape[1]`` without
      a bias), when that is a concrete int.

    Optional output 1 ("present"):
    - with a 5-D past input (input 4), it is the past shape with dim 3 extended
      to the total sequence length;
    - otherwise it is built from ``num_heads`` and input 0's shape.
    """
    shape = self._get_shape(node, 0)
    shape_weights = self._get_shape(node, 1)
    shape_bias = self._try_get_shape(node, 2)
    if shape_bias is not None:
        assert len(shape_bias) == 1
    tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
    if shape and len(shape) == 3:
        qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
        if qkv_hidden_sizes_attr is not None:
            assert len(qkv_hidden_sizes_attr) == 3
            shape[2] = int(qkv_hidden_sizes_attr[2])
        elif isinstance(tripled_hidden_size, int):
            shape[2] = int(tripled_hidden_size / 3)
        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))

        if len(node.output) > 1:
            # input shape: (batch_size, sequence_length, hidden_size)
            # past shape: (2, batch_size, num_heads, past_sequence_length, head_size)
            # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len)
            # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length
            input_shape = self._get_shape(node, 0)
            past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else []
            mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else []

            if past_shape and len(past_shape) == 5:
                if mask_shape and len(mask_shape) in [2, 3]:
                    past_shape[3] = mask_shape[-1]
                elif input_shape and len(input_shape) == 3:
                    if isinstance(input_shape[1], int) and isinstance(past_shape[3], int):
                        past_shape[3] = input_shape[1] + past_shape[3]
                    else:
                        past_shape[3] = f"{past_shape[3]}+{input_shape[1]}"
                vi = self.known_vi_[node.output[1]]
                vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
            else:
                # No past input but present output still exists
                num_heads = get_attribute(node, "num_heads")
                hidden_size = input_shape[2]
                if isinstance(hidden_size, int):
                    head_size = hidden_size // num_heads
                elif hidden_size is None:
                    head_size = None
                else:
                    # Symbolic hidden size: `str // int` raised TypeError. Name the
                    # quotient instead, like the "past+sequence" dim above.
                    head_size = f"{hidden_size}/{num_heads}"
                present_shape = [2, input_shape[0], num_heads, input_shape[1], head_size]
                vi = self.known_vi_[node.output[1]]
                vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
|
|
2254
|
+
|
|
2255
|
+
def _infer_GatedRelativePositionBias(self, node):  # noqa: N802
    """Infer the GatedRelativePositionBias output (batch_size, num_heads, seq_len, seq_len).

    When padding is removed, token_offset (input 6) has a known shape
    (batch_size, seq_len), and both dims are taken from it. Otherwise they come
    from dims 0 and 1 of the 3-D query_layer (input 0). The output element type
    is input 0's.
    """
    num_heads = get_attribute(node, "num_heads")

    token_offset_shape = self._try_get_shape(node, 6)
    if token_offset_shape is None:
        query_layer_shape = self._get_shape(node, 0)
        assert query_layer_shape is not None and len(query_layer_shape) == 3
        batch_dim, seq_dim = query_layer_shape[0], query_layer_shape[1]
    else:
        batch_dim, seq_dim = token_offset_shape[0], token_offset_shape[1]

    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(node.output[0], elem_type, [batch_dim, num_heads, seq_dim, seq_dim])
    )
|
|
2276
|
+
|
|
2277
|
+
def _infer_PackedAttention(self, node):  # noqa: N802
    """Infer the output of PackedAttention (only when input 0 is 2-D).

    The output has input 0's shape with dim 1 replaced by:
    - ``qkv_hidden_sizes[2]`` when that attribute is set;
    - otherwise one third of the bias length (or of ``weights.shape[1]`` without
      a bias), when that is a concrete int.
    """
    shape = self._get_shape(node, 0)
    shape_weights = self._get_shape(node, 1)
    shape_bias = self._try_get_shape(node, 2)
    if shape_bias is None:
        tripled_hidden_size = shape_weights[1]
    else:
        assert len(shape_bias) == 1
        tripled_hidden_size = shape_bias[0]

    if not shape or len(shape) != 2:
        return

    qkv_hidden_sizes = get_attribute(node, "qkv_hidden_sizes")
    if qkv_hidden_sizes is not None:
        assert len(qkv_hidden_sizes) == 3
        shape[1] = int(qkv_hidden_sizes[2])
    elif isinstance(tripled_hidden_size, int):
        shape[1] = int(tripled_hidden_size / 3)

    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, shape))
|
|
2294
|
+
|
|
2295
|
+
def _infer_PackedMultiHeadAttention(self, node):  # noqa: N802
    """Infer the output shape of PackedMultiHeadAttention.

    - With a 2-D value input (input 2), the output takes the value's shape.
    - Otherwise the query (input 0) must be 4-D, and the output is
      ``[query[0], query[1] * query[3]]``. Presumably the query layout is
      (token_count, num_heads, 3, head_size) packed QKV; TODO confirm.

    The output element type is input 0's.
    """
    shape_value = self._try_get_shape(node, 2)
    if shape_value is not None and len(shape_value) == 2:
        output_shape = shape_value
    else:
        shape_query = self._get_shape(node, 0)
        assert shape_query is not None and len(shape_query) == 4
        num_heads, head_size = shape_query[1], shape_query[3]
        if num_heads is None or head_size is None:
            hidden_size = None
        elif isinstance(num_heads, int) and isinstance(head_size, int):
            hidden_size = num_heads * head_size
        else:
            # Symbolic dims are strings: `str * int` silently repeats the string
            # and `str * str` raises TypeError, so name the product explicitly.
            hidden_size = f"{num_heads}*{head_size}"
        output_shape = [shape_query[0], hidden_size]

    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
|
|
2307
|
+
|
|
2308
|
+
def _infer_RemovePadding(self, node):  # noqa: N802
    """Infer RemovePadding outputs from a 3-D input 0.

    Input 0 is presumably (batch_size, sequence_length, hidden_size); TODO
    confirm. Outputs:
    - 0: ``["token_count", dim2]`` in input 0's dtype;
    - 1: INT32 ``[dim0, dim1]``;
    - 2: INT32 ``["batch_size + 1"]``;
    - 3: INT32 ``[1]``.

    Nothing is inferred for other input ranks.
    """
    shape = self._get_shape(node, 0)
    if not shape or len(shape) != 3:
        return

    batch_size, sequence_length, hidden_size = shape
    output_specs = (
        (self.known_vi_[node.input[0]].type.tensor_type.elem_type, ["token_count", hidden_size]),
        (onnx.TensorProto.INT32, [batch_size, sequence_length]),
        (onnx.TensorProto.INT32, ["batch_size + 1"]),
        (onnx.TensorProto.INT32, [1]),
    )
    for out_idx, (elem_type, out_shape) in enumerate(output_specs):
        out_name = node.output[out_idx]
        self.known_vi_[out_name].CopyFrom(helper.make_tensor_value_info(out_name, elem_type, out_shape))
|
|
2327
|
+
|
|
2328
|
+
def _infer_RestorePadding(self, node):  # noqa: N802
    """Infer output 0 of RestorePadding.

    Requires known rank-2 shapes for input 0 and input 1. The output shape is
    ``[input1[0], input1[1], input0[1]]`` with input 0's element type.
    """
    packed_shape = self._get_shape(node, 0)
    offset_shape = self._get_shape(node, 1)
    if not packed_shape or len(packed_shape) != 2:
        return
    if not offset_shape or len(offset_shape) != 2:
        return

    restored_shape = [*offset_shape, packed_shape[1]]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    target = self.known_vi_[node.output[0]]
    target.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, restored_shape))
|
|
2337
|
+
|
|
2338
|
+
def _infer_BiasGelu(self, node):  # noqa: N802
    """Output 0 takes input 0's shape and element type."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
|
|
2340
|
+
|
|
2341
|
+
def _infer_MultiHeadAttention(self, node):  # noqa: N802
    """Infer output 0 and the optional present_key / present_value of MultiHeadAttention.

    Output 0 has shape (batch_size, sequence_length, v_hidden_size).
    Q, K and V without packing:
      Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
      Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
        or (batch_size, num_heads, kv_sequence_length, head_size)
      Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)
        or (batch_size, num_heads, kv_sequence_length, head_size)
    Packed KV:
      Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
      Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size)
      Input 2 nullptr
    Packed QKV:
      Input 0 (batch_size, sequence_length, num_heads, 3, head_size)
      Input 1 nullptr
      Input 2 nullptr

    Outputs 1 and 2 are written only when both are named. Their shapes are
    (batch_size, num_heads, total_sequence_length, head_size) for the key and
    (batch_size, num_heads, total_sequence_length, v_head_size) for the value.
    total_sequence_length is the kv length plus dim 2 of input 6 (past) when present.
    Nothing is inferred when the query shape is unknown or is not rank 3 or 5.
    """
    query_shape = self._get_shape(node, 0)
    if query_shape is None or len(query_shape) not in (3, 5):
        return

    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    total_sequence_length = None

    if len(query_shape) == 3:
        qk_hidden_size = query_shape[2]
        # Copy: overriding the hidden size with v_hidden_size must not alter query_shape.
        output_shape = list(query_shape)
        key_shape = self._try_get_shape(node, 1)
        if key_shape is not None and len(key_shape) == 3:
            value_shape = self._try_get_shape(node, 2)
            if value_shape is not None and len(value_shape) == 3:
                output_shape[2] = value_shape[2]
            total_sequence_length = key_shape[1]
        elif key_shape is not None and len(key_shape) == 4:
            # Key given as (batch_size, num_heads, kv_sequence_length, head_size).
            total_sequence_length = key_shape[2]
        elif key_shape is not None and len(key_shape) == 5:
            # Packed KV: (batch_size, kv_sequence_length, num_heads, 2, head_size).
            total_sequence_length = key_shape[1]
    else:
        if isinstance(query_shape[2], int) and isinstance(query_shape[4], int):
            packed_hidden = query_shape[2] * query_shape[4]
        else:
            packed_hidden = f"{query_shape[2]}*{query_shape[4]}"
        output_shape = [query_shape[0], query_shape[1], packed_hidden]
        total_sequence_length = query_shape[1]

    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))

    if not (len(node.output) > 2 and node.output[1] and node.output[2]):
        return

    num_heads = get_attribute(node, "num_heads")

    def per_head(hidden):
        # Split a hidden size across heads; keep a symbolic expression for non-int dims.
        return hidden // num_heads if isinstance(hidden, int) else f"{hidden}/{num_heads}"

    if len(query_shape) == 3:
        key_head_size = per_head(qk_hidden_size)
        value_head_size = per_head(output_shape[2])
    else:
        key_head_size = value_head_size = query_shape[4]

    past_shape = self._try_get_shape(node, 6)
    if past_shape is not None:
        if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int):
            total_sequence_length = past_shape[2] + total_sequence_length
        else:
            total_sequence_length = f"{past_shape[2]}+{total_sequence_length}"

    batch_size = query_shape[0]
    present_key_shape = [batch_size, num_heads, total_sequence_length, key_head_size]
    present_value_shape = [batch_size, num_heads, total_sequence_length, value_head_size]
    for out_idx, present_shape in ((1, present_key_shape), (2, present_value_shape)):
        vi = self.known_vi_[node.output[out_idx]]
        vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
|
|
2416
|
+
|
|
2417
|
+
def _infer_DecoderMaskedMultiHeadAttention(self, node):  # noqa: N802
    """Infer outputs of DecoderMaskedMultiHeadAttention.

    Output 0 has shape (batch_size, 1, v_hidden_size).
    Q, K and V without packing:
      Input 0 (query) has shape (batch_size, 1, hidden_size)
      Input 5 (past_key) if exists has shape (batch_size, num_heads, max_sequence_length, head_size)

    Output 0 copies input 0's shape and element type. When outputs 1 and 2 are both named
    and input 5 has a known shape, both receive input 5's shape with input 0's element type.
    """
    query_shape = self._get_shape(node, 0)
    if query_shape is None:
        return

    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    assert elem_type is not None
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, query_shape))

    wants_present = len(node.output) > 2 and node.output[1] and node.output[2]
    if not wants_present:
        return
    past_shape = self._try_get_shape(node, 5)
    if past_shape is None:
        return
    for out_idx in (1, 2):
        present_vi = self.known_vi_[node.output[out_idx]]
        present_vi.CopyFrom(helper.make_tensor_value_info(present_vi.name, elem_type, past_shape))
|
|
2438
|
+
|
|
2439
|
+
def _infer_UnfoldTensor(self, node):  # noqa: N802
    """Infer the output of UnfoldTensor.

    The output keeps input 0's shape, replaces dimension ``dim`` with the window count
    ``(input_shape[dim] - size) // step + 1`` and appends a trailing dimension of length
    ``size``. The element type is copied from input 0. A symbolic ``input_shape[dim]``
    produces a symbolic expression string; an unknown (None) one stays unknown.
    """
    input_shape = self._get_shape(node, 0)
    if input_shape is None:
        return
    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    assert output_dtype is not None

    rank = len(input_shape)
    # Defaults follow the com.microsoft UnfoldTensor schema (dim=-1, step=1); size is required.
    dim, size, step = -1, None, 1
    for attr in node.attribute:
        if attr.name == "dim":
            dim = attr.i
        elif attr.name == "size":
            size = attr.i
        elif attr.name == "step":
            step = attr.i
    assert size is not None, f"UnfoldTensor '{node.name}' has no size attribute."

    # Normalize every negative axis (not only -1). A raw negative index would address the
    # wrong dimension once the trailing `size` dim is appended below.
    if dim < 0:
        dim += rank

    dim_len = input_shape[dim]
    if isinstance(dim_len, int):
        num_windows = (dim_len - size) // step + 1
    elif dim_len is None:
        num_windows = None
    else:
        num_windows = f"({dim_len}-{size})//{step}+1"

    output_shape = list(input_shape)
    output_shape[dim] = num_windows
    output_shape.append(size)

    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
|
|
2461
|
+
|
|
2462
|
+
def _infer_DynamicTimeWarping(self, node):  # noqa: N802
    """Infer the output of DynamicTimeWarping.

    Input 0 has shape M x N or 1 x M x N.
    Output 0 has shape (2, O) where max(M, N) <= O < M + N. O is recorded as a
    descriptive symbolic dim because its exact value depends on the data.
    """
    input_shape = self._get_shape(node, 0)
    if input_shape is None:
        return
    assert len(input_shape) in (2, 3)
    M, N = input_shape[-2], input_shape[-1]  # noqa: N806
    output_shape = [2, f"max({M}, {N}) <= O < {M} + {N}"]
    # The output holds path indices; the contrib-op schema constrains it to int32.
    output_dtype = onnx.TensorProto.INT32
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
|
|
2474
|
+
|
|
2475
|
+
def _infer_FastGelu(self, node):  # noqa: N802
    """Output 0 takes input 0's shape and element type."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
|
|
2477
|
+
|
|
2478
|
+
def _infer_Gelu(self, node):  # noqa: N802
    """Output 0 takes input 0's shape and element type."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
|
|
2480
|
+
|
|
2481
|
+
def _infer_QuickGelu(self, node):  # noqa: N802
    """Output 0 takes input 0's shape and element type."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
|
|
2483
|
+
|
|
2484
|
+
def _infer_GemmFastGelu(self, node):  # noqa: N802
    """Infer the output with the same MatMul shape rule used for its first two inputs."""
    matmul_infer = self._compute_matmul_shape
    matmul_infer(node)
|
|
2486
|
+
|
|
2487
|
+
def _infer_GemmFloat8(self, node):  # noqa: N802
    """Infer the output with the same MatMul shape rule used for its first two inputs."""
    matmul_infer = self._compute_matmul_shape
    matmul_infer(node)
|
|
2489
|
+
|
|
2490
|
+
def _infer_LayerNormalization(self, node):  # noqa: N802
    """Infer LayerNormalization outputs.

    Output 0 copies input 0's shape and element type. Optional outputs 1 and 2 keep
    input 0's dims before ``axis`` (default -1) and have size 1 in every dim from
    ``axis`` on. Their element type is input 0's, promoted to FLOAT when input 0 is
    FLOAT16 or BFLOAT16. Unnamed ("") optional outputs are skipped, matching the
    name checks used for other multi-output ops in this class.
    """
    self._propagate_shape_and_type(node)
    stat_outputs = [name for name in node.output[1:3] if name]
    if not stat_outputs:
        return

    axis = get_attribute(node, "axis")
    if axis is None:
        axis = -1
    x_shape = self._get_shape(node, 0)
    if x_shape is None:
        return

    rank = len(x_shape)
    axis = handle_negative_axis(axis, rank)
    mean_shape = x_shape[:axis] + [1] * (rank - axis)
    mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    if mean_dtype in (onnx.TensorProto.FLOAT16, onnx.TensorProto.BFLOAT16):
        mean_dtype = onnx.TensorProto.FLOAT
    for name in stat_outputs:
        vi = self.known_vi_[name]
        vi.CopyFrom(helper.make_tensor_value_info(name, mean_dtype, mean_shape))
|
|
2509
|
+
|
|
2510
|
+
def _infer_LongformerAttention(self, node):  # noqa: N802
    """Output 0 takes input 0's shape and element type."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
|
|
2512
|
+
|
|
2513
|
+
def _infer_EmbedLayerNormalization(self, node):  # noqa: N802
    """Infer EmbedLayerNormalization outputs.

    Input 0 (ids) and input 2 (word embedding table) must be rank 2.
      * output 0 -> ``[ids[0], ids[1], word_embedding[1]]`` with input 2's element type
      * output 1 (if named) -> ``[ids[0]]`` int32
      * output 2 (if present) -> same shape and type as output 0
    """
    ids_shape = self._get_shape(node, 0)
    table_shape = self._get_shape(node, 2)
    assert len(ids_shape) == 2 and len(table_shape) == 2
    batch_dim, seq_dim = ids_shape
    hidden_shape = [batch_dim, seq_dim, table_shape[1]]
    emb_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type

    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], emb_dtype, hidden_shape))

    if len(node.output) > 1 and node.output[1]:
        mask_vi = self.known_vi_[node.output[1]]
        mask_vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, [batch_dim]))

    if len(node.output) > 2:
        # Optional output holding the sum before layer normalization; same shape as output 0.
        sum_vi = self.known_vi_[node.output[2]]
        sum_vi.CopyFrom(helper.make_tensor_value_info(node.output[2], emb_dtype, hidden_shape))
|
|
2533
|
+
|
|
2534
|
+
def _infer_SkipLayerNormalization(self, node):  # noqa: N802
    """Output 0, and optional output 3 when present, take input 0's shape and element type."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
    has_optional_output = len(node.output) > 3
    if has_optional_output:
        self._propagate_shape_and_type(node, input_index=0, output_index=3)
|
|
2541
|
+
|
|
2542
|
+
def _infer_GroupNorm(self, node):  # noqa: N802
    """Output 0 takes input 0's shape and element type."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
|
|
2544
|
+
|
|
2545
|
+
def _infer_PagedAttention(self, node):  # noqa: N802
    """Output 0 takes input 0's shape and element type."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
|
|
2547
|
+
|
|
2548
|
+
def _infer_GroupQueryAttention(self, node):  # noqa: N802
    """Infer GroupQueryAttention outputs.

    * Outputs 1 and 2 copy input 3's shape when it is known, using input 0's element type.
    * With separate key/value inputs, output 0 copies input 0.
    * With packed QKV (inputs 1 and 2 empty), output 0 is input 0's shape with dim 2
      reduced to ``num_heads * head_size``. This is written only when that dim is a
      concrete int.
    """
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type

    past_key_shape = self._try_get_shape(node, 3)
    if past_key_shape is not None:
        # When past and present have the maximum sequence length, the past shape carries
        # over to present. GQA also supports differing past/present lengths, but that is
        # rarely used.
        for out_idx in (1, 2):
            present_vi = self.known_vi_[node.output[out_idx]]
            present_vi.CopyFrom(helper.make_tensor_value_info(present_vi.name, elem_type, past_key_shape))

    if node.input[1] != "" and node.input[2] != "":
        self._propagate_shape_and_type(node, 0, 0)
        return

    # combined qkv: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size)
    assert node.input[1] == "" and node.input[2] == ""
    num_heads = get_attribute(node, "num_heads")
    kv_num_heads = get_attribute(node, "kv_num_heads")
    packed_shape = self._get_shape(node, 0)
    if packed_shape is None:
        return
    packed_hidden = packed_shape[2]
    if not isinstance(packed_hidden, int):
        return
    head_size = int(packed_hidden / (num_heads + 2 * kv_num_heads))
    packed_shape[2] = num_heads * head_size
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, packed_shape))
|
|
2575
|
+
|
|
2576
|
+
def _infer_SparseAttention(self, node):  # noqa: N802
    """Reuse the GroupQueryAttention inference rule for SparseAttention."""
    gqa_infer = self._infer_GroupQueryAttention
    gqa_infer(node)
|
|
2578
|
+
|
|
2579
|
+
def _infer_SkipGroupNorm(self, node):  # noqa: N802
    """Output 0, and output 1 when present, take input 0's shape and element type."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
    if len(node.output) >= 2:
        self._propagate_shape_and_type(node, input_index=0, output_index=1)
|
|
2583
|
+
|
|
2584
|
+
def _infer_BiasSplitGelu(self, node):  # noqa: N802
    """Infer output 0 as input 0's shape with dim 2 set to half of bias dim 0.

    Written only when both shapes are known and bias dim 0 is a concrete int.
    """
    input_shape = self._get_shape(node, 0)
    bias_shape = self._get_shape(node, 1)
    if not (input_shape and bias_shape and isinstance(bias_shape[0], int)):
        return
    input_shape[2] = int(bias_shape[0] / 2)
    out_vi = self.known_vi_[node.output[0]]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(out_vi.name, elem_type, input_shape))
|
|
2593
|
+
|
|
2594
|
+
def _infer_BiasAdd(self, node):  # noqa: N802
    """Output 0 takes input 0's shape and element type."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
|
|
2596
|
+
|
|
2597
|
+
def _infer_RotaryEmbedding(self, node):  # noqa: N802
    """Infer RotaryEmbedding outputs.

    With one output, it copies input 0. RotaryEmbedding functions built with
    `export_modules_as_functions` can emit 2 or 3 outputs. There, every output except
    the last is an extraneous constant taking input 1's shape/type, and the last output
    (the true result) takes input 0's. Other output counts are left untouched.
    """
    output_count = len(node.output)
    if output_count == 1:
        self._propagate_shape_and_type(node)
        return
    if output_count not in (2, 3):
        return
    for extra_idx in range(output_count - 1):
        self._propagate_shape_and_type(node, input_index=1, output_index=extra_idx)
    self._propagate_shape_and_type(node, input_index=0, output_index=output_count - 1)  # true output
|
|
2609
|
+
|
|
2610
|
+
def _infer_PythonOp(self, node):  # noqa: N802
    """Infer outputs of a PythonOp node.

    Output 0 (the torch.autograd.Function context) is typed as an INT64 scalar.
    Remaining outputs come from the shape-inference function that
    ``get_shape_inference_function(func_name)`` returns, when it returns one. Otherwise
    each output gets a fresh symbolic shape of rank ``output_tensor_ranks[k]`` and type
    ``output_tensor_types[k]``; ranks are assumed fixed across model inputs.

    Raises:
        AssertionError: when the type/rank attributes are missing, or when the custom
            inferer returns a count that does not match the tensor outputs.
    """
    output_tensor_types = get_attribute(node, "output_tensor_types")
    assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute."
    output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
    assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute."

    from onnxruntime.capi._pybind_state import get_shape_inference_function  # noqa: PLC0415

    func_name = get_attribute(node, "func_name").decode()
    shape_inferer = get_shape_inference_function(func_name)

    # The first output is torch.autograd.Function's context.
    ctx_vi = self.known_vi_[node.output[0]]
    ctx_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))

    tensor_output_count = len(node.output) - 1

    if shape_inferer is None:
        for k in range(tensor_output_count):
            out_name = node.output[k + 1]
            out_vi = self.known_vi_[out_name]
            fresh_dims = get_shape_from_sympy_shape(self._new_symbolic_shape(output_tensor_ranks[k], node))
            out_vi.CopyFrom(helper.make_tensor_value_info(out_name, output_tensor_types[k], fresh_dims))
        return

    input_shapes = []
    input_dtypes = []
    for in_idx, in_name in enumerate(node.input):
        input_shapes.append(self._get_shape(node, in_idx))
        input_dtypes.append(self.known_vi_[in_name].type.tensor_type.elem_type)
    output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes)
    assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), (
        f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, "
        f"but expected {len(node.output) - 1} outputs."
    )
    for k in range(tensor_output_count):
        out_name = node.output[k + 1]
        out_vi = self.known_vi_[out_name]
        out_vi.CopyFrom(helper.make_tensor_value_info(out_name, output_dtypes[k], output_shapes[k]))
|
|
2656
|
+
|
|
2657
|
+
def _propagate_shape_and_type(self, node, input_index=0, output_index=0):
    """Copy the shape and element type of ``node.input[input_index]`` onto ``node.output[output_index]``."""
    src_shape = self._get_shape(node, input_index)
    src_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
    dst_name = node.output[output_index]
    self.known_vi_[dst_name].CopyFrom(helper.make_tensor_value_info(dst_name, src_dtype, src_shape))
|
|
2662
|
+
|
|
2663
|
+
def _is_none_dim(self, dim_value):
    """Return True if ``dim_value`` is an "unk__" placeholder string not registered in ``symbolic_dims_``.

    Only exact ``str`` instances qualify (subclasses do not).
    """
    is_plain_str = type(dim_value) is str
    return is_plain_str and "unk__" in dim_value and dim_value not in self.symbolic_dims_
|
|
2671
|
+
|
|
2672
|
+
def _is_shape_contains_none_dim(self, out_shape):
    """Return the first dim of ``out_shape`` that ``_is_none_dim`` accepts, or None if there is none."""
    return next((dim for dim in out_shape if self._is_none_dim(dim)), None)
|
|
2677
|
+
|
|
2678
|
+
def _infer_impl(self, start_sympy_data=None):
    """Run one pass of symbolic shape inference over the nodes of ``self.out_mp_``.

    Args:
        start_sympy_data: initial mapping for ``self.sympy_data_``; an empty dict is
            used when it is falsy.

    Returns:
        True when every processed node output ended with a complete shape. In that case
        ``self.run_`` is reset to False.
        False when some output still has an incomplete shape (a None or "unk__" dim) or
        an undefined type, and no dynamic dims were guessed for it. On this path
        ``self.run_`` is True only if suggested dim merges were just recorded, so a
        caller looping on ``run_`` (as ``infer_shapes`` does) retries.

    Raises:
        Exception: "Invalid model with cyclic graph" when the topological sort stalls.
    """
    self.sympy_data_ = start_sympy_data or {}
    # Start this pass from an empty value_info list.
    self.out_mp_.graph.ClearField("value_info")
    # NOTE(review): presumably applies merges suggested in earlier passes, limited to graph inputs.
    self._apply_suggested_merge(graph_input_only=True)
    self.input_symbols_ = set()
    for i in self.out_mp_.graph.input:
        input_shape = get_shape_from_value_info(i)
        if input_shape is None:
            continue

        if is_sequence(i.type):
            input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
        else:
            input_dims = i.type.tensor_type.shape.dim

        for i_dim, dim in enumerate(input_shape):
            if dim is None:
                # some models use None for symbolic dim in input, replace it with a string
                input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim))

        # Record the named (string) dims of this graph input.
        self.input_symbols_.update([d for d in input_shape if type(d) is str])

    for s in self.input_symbols_:
        if s in self.suggested_merge_:
            # Alias the symbol to the one it was suggested to merge into.
            s_merge = self.suggested_merge_[s]
            assert s_merge in self.symbolic_dims_
            self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
        else:
            # Since inputs are not produced by other ops, we can assume positivity
            self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)
    # create a temporary ModelProto for single node inference
    # note that we remove initializer to have faster inference
    # for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways
    self.tmp_mp_ = onnx.ModelProto()
    self.tmp_mp_.CopyFrom(self.out_mp_)
    self.tmp_mp_.graph.ClearField("initializer")

    # compute prerequisite for node for topological sort
    # node with subgraphs may have dependency on implicit inputs, which will affect topological sort
    prereq_for_node = {}  # map from node to all its inputs, including implicit ones in subgraph

    def get_prereq(node):
        # Names this node depends on: its explicit inputs plus any outer-scope names
        # referenced inside If/Loop/Scan subgraphs (recursively).
        names = {i for i in node.input if i}
        subgraphs = []
        if node.op_type == "If":
            subgraphs = [
                get_attribute(node, "then_branch"),
                get_attribute(node, "else_branch"),
            ]
        elif node.op_type in ["Loop", "Scan"]:
            subgraphs = [get_attribute(node, "body")]
        for g in subgraphs:
            # Names produced inside the subgraph are not outer dependencies.
            g_outputs_and_initializers = {i.name for i in g.initializer}
            g_prereq = set()
            for n in g.node:
                g_outputs_and_initializers.update(n.output)
            for n in g.node:
                g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers])
            names.update(g_prereq)
            # remove subgraph inputs from g_prereq since those are local-only
            for i in g.input:
                names.discard(i.name)
        return names

    # Nodes are keyed by their first output name.
    for n in self.tmp_mp_.graph.node:
        prereq_for_node[n.output[0]] = get_prereq(n)

    # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate
    sorted_nodes = []
    sorted_known_vi = {i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)}
    if any(o.name in sorted_known_vi for o in self.out_mp_.graph.output):
        # Loop/Scan will have some graph output in graph inputs, so don't do topological sort
        sorted_nodes = self.out_mp_.graph.node
    else:
        while not all(o.name in sorted_known_vi for o in self.out_mp_.graph.output):
            old_sorted_nodes_len = len(sorted_nodes)
            for node in self.out_mp_.graph.node:
                if (node.output[0] not in sorted_known_vi) and all(
                    i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i
                ):
                    sorted_known_vi.update(node.output)
                    sorted_nodes.append(node)
            # A full sweep that schedules nothing while outputs are still unreached means a cycle.
            if old_sorted_nodes_len == len(sorted_nodes) and not all(
                o.name in sorted_known_vi for o in self.out_mp_.graph.output
            ):
                raise Exception("Invalid model with cyclic graph")

    for node in sorted_nodes:
        assert all(i in self.known_vi_ for i in node.input if i)
        # ONNX's own per-node inference runs first; symbolic handlers then refine the result.
        self._onnx_infer_single_node(node)
        known_aten_op = False
        if node.op_type in self.dispatcher_:
            self.dispatcher_[node.op_type](node)
        elif node.op_type in ["ConvTranspose"]:
            # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input
            # before adding symbolic compute for them
            # mark the output type as UNDEFINED to allow guessing of rank
            vi = self.known_vi_[node.output[0]]
            if len(vi.type.tensor_type.shape.dim) == 0:
                vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
        elif node.op_type == "ATen" and node.domain == "org.pytorch.aten":
            for attr in node.attribute:
                # TODO: Is overload_name needed?
                if attr.name == "operator":
                    aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
                    if aten_op_name in self.aten_op_dispatcher_:
                        known_aten_op = True
                        self.aten_op_dispatcher_[aten_op_name](node)
                    break

        if self.verbose_ > 2:
            logger.debug(node.op_type + ": " + node.name)  # noqa: G003
            for i, name in enumerate(node.input):
                logger.debug(" Input %s: %s %s", i, name, "initializer" if name in self.initializers_ else "")

        # onnx automatically merge dims with value, i.e. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb']
        # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case
        if node.op_type in [
            "Add",
            "Sub",
            "Mul",
            "Div",
            "MatMul",
            "MatMulInteger",
            "MatMulInteger16",
            "Where",
            "Sum",
        ]:
            vi = self.known_vi_[node.output[0]]
            out_rank = len(get_shape_from_type_proto(vi.type))
            in_shapes = [self._get_shape(node, i) for i in range(len(node.input))]
            # Broadcast dims align from the right; MatMul's last two dims are not broadcast.
            for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)):
                in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
                if len(in_dims) > 1:
                    self._check_merged_dims(in_dims, allow_broadcast=True)

        for i_o in range(len(node.output)):
            # Special cases:
            # 1) We do not care about the training related outputs of SkipLayerNormalization
            # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because
            # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding
            # contrib op
            if (
                node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization"
            ) and i_o in [1, 2]:
                continue
            if node.op_type == "RotaryEmbedding" and len(node.output) > 1:
                # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs
                # generated by `export_modules_as_functions`
                continue

            vi = self.known_vi_[node.output[i_o]]
            out_type = vi.type
            out_type_kind = out_type.WhichOneof("value")

            # do not process shape for non-tensors
            if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]:
                if self.verbose_ > 2:
                    if out_type_kind == "sequence_type":
                        seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
                        if seq_cls_type == "tensor_type":
                            logger.debug(
                                " {}: sequence of {} {}".format(  # noqa: G001
                                    node.output[i_o],
                                    str(get_shape_from_value_info(vi)),
                                    onnx.TensorProto.DataType.Name(
                                        vi.type.sequence_type.elem_type.tensor_type.elem_type
                                    ),
                                )
                            )
                        else:
                            logger.debug(f" {node.output[i_o]}: sequence of {seq_cls_type}")
                    else:
                        logger.debug(f" {node.output[i_o]}: {out_type_kind}")
                continue

            out_shape = get_shape_from_value_info(vi)
            out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
            if self.verbose_ > 2:
                logger.debug(
                    f" {node.output[i_o]}: {out_shape!s} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}"
                )
                if node.output[i_o] in self.sympy_data_:
                    logger.debug(" Sympy Data: " + str(self.sympy_data_[node.output[i_o]]))  # noqa: G003

            # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain
            if (
                out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape))
            ) or out_type_undefined:
                # Incomplete output: try to record dim merges (auto_merge), otherwise guess
                # fresh dynamic dims, otherwise stop this pass.
                if self.auto_merge_:
                    if node.op_type in [
                        "Add",
                        "Sub",
                        "Mul",
                        "Div",
                        "MatMul",
                        "MatMulInteger",
                        "MatMulInteger16",
                        "Concat",
                        "Where",
                        "Sum",
                        "Equal",
                        "Less",
                        "Greater",
                        "LessOrEqual",
                        "GreaterOrEqual",
                        "Min",
                        "Max",
                    ]:
                        shapes = [self._get_shape(node, i) for i in range(len(node.input))]
                        if node.op_type in [
                            "MatMul",
                            "MatMulInteger",
                            "MatMulInteger16",
                        ]:
                            if None in out_shape or self._is_shape_contains_none_dim(out_shape):
                                if None in out_shape:
                                    idx = out_shape.index(None)
                                else:
                                    idx = out_shape.index(self._is_shape_contains_none_dim(out_shape))
                                dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
                                # only support auto merge for MatMul for dim < rank-2 when rank > 2
                                assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2
                                assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2
                    elif node.op_type == "Expand":
                        # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq])
                        shapes = [
                            self._get_shape(node, 0),
                            self._get_value(node, 1),
                        ]
                    else:
                        shapes = []

                    if shapes:
                        for idx in range(len(out_shape)):
                            if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]):
                                continue
                            # note that the broadcasting rule aligns from right to left
                            # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge
                            dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
                            if len(dim_idx) > 0:
                                self._add_suggested_merge(
                                    [
                                        s[i] if is_literal(s[i]) else str(s[i])
                                        for s, i in zip(shapes, dim_idx, strict=False)
                                        if i >= 0
                                    ]
                                )
                        # Merges were recorded; the caller should run another pass.
                        self.run_ = True
                    else:
                        self.run_ = False
                else:
                    self.run_ = False

                # create new dynamic dims for ops not handled by symbolic shape inference
                if self.run_ is False and node.op_type not in self.dispatcher_ and not known_aten_op:
                    is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0)
                    if is_unknown_op:
                        # unknown op to ONNX, maybe from higher opset or other domain
                        # only guess the output rank from input 0 when using guess_output_rank option
                        out_rank = self._get_shape_rank(node, 0) if self.guess_output_rank_ else -1
                    else:
                        # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape
                        out_rank = len(out_shape)

                    if out_rank >= 0:
                        new_shape = self._new_symbolic_shape(out_rank, node, i_o)
                        if out_type_undefined:
                            # guess output data type from input vi if not defined
                            out_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
                        else:
                            # otherwise, use original data type
                            out_dtype = vi.type.tensor_type.elem_type
                        vi.CopyFrom(
                            helper.make_tensor_value_info(
                                vi.name,
                                out_dtype,
                                get_shape_from_sympy_shape(new_shape),
                            )
                        )

                        if self.verbose_ > 0:
                            if is_unknown_op:
                                logger.debug(
                                    f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape"
                                )
                            if self.verbose_ > 2:
                                logger.debug(f" {node.output[i_o]}: {new_shape!s} {vi.type.tensor_type.elem_type}")

                        self.run_ = True
                        continue  # continue the inference after guess, no need to stop as no merge is needed

                if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
                    logger.debug("Stopping at incomplete shape inference at %s: %s", node.op_type, node.name)
                    logger.debug("node inputs:")
                    for i in node.input:
                        if i in self.known_vi_:
                            logger.debug(self.known_vi_[i])
                        else:
                            logger.debug(f"not in known_vi_ for {i}")
                    logger.debug("node outputs:")
                    for o in node.output:
                        if o in self.known_vi_:
                            logger.debug(self.known_vi_[o])
                        else:
                            logger.debug(f"not in known_vi_ for {o}")
                    if self.auto_merge_ and not out_type_undefined:
                        logger.debug("Merging: " + str(self.suggested_merge_))  # noqa: G003
                # End this pass early; self.run_ tells the caller whether to retry.
                return False

    # Every output was resolved: no further pass is needed.
    self.run_ = False
    return True
|
|
2990
|
+
|
|
2991
|
+
def _update_output_from_vi(self):
    """Overwrite graph outputs of ``self.out_mp_`` with their inferred value info.

    Each entry of ``self.out_mp_.graph.output`` whose name is a key of
    ``self.known_vi_`` is replaced in place (via ``CopyFrom``) by that value
    info. Outputs without an entry are left untouched.
    """
    inferred_by_name = self.known_vi_
    for graph_output in self.out_mp_.graph.output:
        name = graph_output.name
        if name not in inferred_by_name:
            continue
        graph_output.CopyFrom(inferred_by_name[name])
|
|
2995
|
+
|
|
2996
|
+
@staticmethod
def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
    """Run symbolic shape inference over an ONNX model and return the result.

    Args:
        in_mp: The input ONNX model.
        int_max: Maximum value for an integer to be treated as boundless for
            ops like Slice.
        auto_merge: Automatically merge symbolic dims when a conflict happens.
        guess_output_rank: Guess output rank to be the same as input 0 for
            unknown ops.
        verbose: Logging verbosity level.

    Returns:
        The inferred model (``out_mp_`` of the inference engine), or ``None``
        when the model's opset is missing or below 7. In that case a warning
        is logged instead.

    Raises:
        Exception: if inference finishes without inferring every shape. Before
            raising, the partially inferred model is saved as
            "sym_shape_infer_temp.onnx" (relative to the current working
            directory) with external data, to aid debugging.
    """
    onnx_opset = get_opset(in_mp)
    if not onnx_opset or onnx_opset < 7:
        logger.warning("Only support models of onnx opset 7 and above.")
        return None

    engine = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose)
    engine._preprocess(in_mp)

    # _infer_impl sets run_ to request another pass (e.g. after a guess or a
    # merge); its return value from the final pass is what decides success.
    all_shapes_inferred = False
    while engine.run_:
        all_shapes_inferred = engine._infer_impl()
    engine._update_output_from_vi()

    if all_shapes_inferred:
        return engine.out_mp_

    onnx.save_model(engine.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
    raise Exception("Incomplete symbolic shape inference")
|
|
3012
|
+
|
|
3013
|
+
|
|
3014
|
+
def parse_arguments(argv=None):
    """Build the command-line parser for this script and parse arguments.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which is what the
            script entry point relies on. Passing an explicit list lets other
            Python code and tests reuse the parser without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``input``, ``output``,
        ``auto_merge``, ``int_max``, ``guess_output_rank``, ``verbose``,
        ``save_as_external_data``, ``all_tensors_to_one_file``,
        ``external_data_location`` and ``external_data_size_threshold``.

    Raises:
        SystemExit: raised by argparse for invalid arguments, including a
            missing required ``--input``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="The input model file")
    parser.add_argument("--output", help="The output model file")
    parser.add_argument(
        "--auto_merge",
        help="Automatically merge symbolic dims when confliction happens",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--int_max",
        help="maximum value for integer to be treated as boundless for ops like slice",
        type=int,
        default=2**31 - 1,
    )
    parser.add_argument(
        "--guess_output_rank",
        help="guess output rank to be the same as input 0 for unknown ops",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--verbose",
        help="Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--save_as_external_data",
        help="Saving an ONNX model to external data",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--all_tensors_to_one_file",
        help="Saving all the external data to one file",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--external_data_location",
        help="The file location to save the external file",
        default="./",
    )
    parser.add_argument(
        "--external_data_size_threshold",
        help="The size threshold for external data",
        type=int,
        default=1024,
    )
    # parse_args(None) falls back to sys.argv[1:], preserving the original
    # zero-argument behaviour.
    return parser.parse_args(argv)
|
|
3066
|
+
|
|
3067
|
+
|
|
3068
|
+
if __name__ == "__main__":
    # Command-line entry point: load --input, run symbolic shape inference and,
    # when --output is given, save the inferred model.
    args = parse_arguments()
    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info("input model: %s", args.input)
    if args.output:
        logger.info("output model %s", args.output)
    logger.info("Doing symbolic shape inference...")
    out_mp = SymbolicShapeInference.infer_shapes(
        onnx.load(args.input),
        int_max=args.int_max,
        auto_merge=args.auto_merge,
        guess_output_rank=args.guess_output_rank,
        verbose=args.verbose,
    )
    # infer_shapes returns None (after logging a warning) for models whose
    # opset is missing or below 7; there is nothing to save in that case.
    if args.output and out_mp:
        if args.save_as_external_data:
            onnx.save_model(
                out_mp,
                args.output,
                save_as_external_data=True,
                all_tensors_to_one_file=args.all_tensors_to_one_file,
                location=args.external_data_location,
                size_threshold=args.external_data_size_threshold,
                convert_attribute=False,
            )
        else:
            onnx.save(out_mp, args.output)
        logger.info("Done!")
|