onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
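
The hunk below is the added content of onnxruntime/tools/symbolic_shape_infer.py (file 160 in the list above, the +3094-line entry): a standalone pass that propagates symbolic dimension expressions through an ONNX graph with sympy, resolving dynamic shapes that stock ONNX shape inference leaves unknown. A minimal usage sketch, assuming the SymbolicShapeInference.infer_shapes entry point defined further down in the file (past this excerpt) and hypothetical model paths:

    # Minimal sketch, not part of the package diff.
    # Assumptions: SymbolicShapeInference.infer_shapes is the module's public
    # entry point (defined later in the file, beyond the excerpt below);
    # "model.onnx" and "model_with_shapes.onnx" are hypothetical paths.
    import onnx
    from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

    model = onnx.load("model.onnx")
    inferred = SymbolicShapeInference.infer_shapes(
        model,
        int_max=2**31 - 1,  # clamp used when folding Min/Max over dims
        auto_merge=True,  # merge conflicting symbolic dims instead of only warning
        guess_output_rank=False,  # don't guess output ranks for unknown ops
        verbose=0,
    )
    onnx.save(inferred, "model_with_shapes.onnx")

The same pass is also exposed as a script (note the argparse import below); the exact CLI flags live in the part of the file past this excerpt.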
@@ -0,0 +1,3094 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ # -*- coding: UTF-8 -*-
5
+ import argparse
6
+ import logging
7
+
8
+ import numpy as np
9
+ import onnx
10
+ import sympy
11
+ from onnx import helper, numpy_helper, shape_inference
12
+ from packaging import version
13
+
14
+ assert version.parse(onnx.__version__) >= version.parse("1.8.0")
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def get_attribute(node, attr_name, default_value=None):
20
+ found = [attr for attr in node.attribute if attr.name == attr_name]
21
+ if found:
22
+ return helper.get_attribute_value(found[0])
23
+ return default_value
24
+
25
+
26
+ def get_dim_from_proto(dim):
27
+ return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) is str else None
28
+
29
+
30
+ def is_sequence(type_proto):
31
+ cls_type = type_proto.WhichOneof("value")
32
+ assert cls_type in ["tensor_type", "sequence_type"]
33
+ return cls_type == "sequence_type"
34
+
35
+
36
+ def get_shape_from_type_proto(type_proto):
37
+ assert not is_sequence(type_proto)
38
+ if type_proto.tensor_type.HasField("shape"):
39
+ return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim]
40
+ else:
41
+ return None # note no shape is different from shape without dim (scalar)
42
+
43
+
44
+ def get_elem_type_from_type_proto(type_proto):
45
+ if is_sequence(type_proto):
46
+ return type_proto.sequence_type.elem_type.tensor_type.elem_type
47
+ else:
48
+ return type_proto.tensor_type.elem_type
49
+
50
+
51
+ def get_shape_from_value_info(vi):
52
+ cls_type = vi.type.WhichOneof("value")
53
+ if cls_type is None:
54
+ return None
55
+ if is_sequence(vi.type):
56
+ if vi.type.sequence_type.elem_type.WhichOneof("value") == "tensor_type":
57
+ return get_shape_from_type_proto(vi.type.sequence_type.elem_type)
58
+ else:
59
+ return None
60
+ else:
61
+ return get_shape_from_type_proto(vi.type)
62
+
63
+
64
+ def make_named_value_info(name):
65
+ vi = onnx.ValueInfoProto()
66
+ vi.name = name
67
+ return vi
68
+
69
+
70
+ def get_shape_from_sympy_shape(sympy_shape):
71
+ return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape]
72
+
73
+
74
+ def is_literal(dim):
75
+ return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, "is_number") and dim.is_number)
76
+
77
+
78
+ def handle_negative_axis(axis, rank):
79
+ assert axis < rank and axis >= -rank
80
+ return axis if axis >= 0 else rank + axis
81
+
82
+
83
+ def get_opset(mp, domain=None):
84
+ domain = domain or ["", "onnx", "ai.onnx"]
85
+ if type(domain) != list: # noqa: E721
86
+ domain = [domain]
87
+ for opset in mp.opset_import:
88
+ if opset.domain in domain:
89
+ return opset.version
90
+
91
+ return None
92
+
93
+
94
+ def as_scalar(x):
95
+ if type(x) is list:
96
+ assert len(x) == 1
97
+ return x[0]
98
+ elif type(x) is np.ndarray:
99
+ return x.item()
100
+ else:
101
+ return x
102
+
103
+
104
+ def as_list(x, keep_none):
105
+ if type(x) is list:
106
+ return x
107
+ elif type(x) is np.ndarray:
108
+ return list(x)
109
+ elif keep_none and x is None:
110
+ return None
111
+ else:
112
+ return [x]
113
+
114
+
115
+ def sympy_reduce_product(x):
116
+ if type(x) is list:
117
+ value = sympy.Integer(1)
118
+ for v in x:
119
+ value = value * v
120
+ else:
121
+ value = x
122
+ return value
123
+
124
+
125
+ class SymbolicShapeInference:
126
+ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
127
+ self.dispatcher_ = {
128
+ "Add": self._infer_symbolic_compute_ops,
129
+ "AllReduce": self._pass_on_shape_and_type,
130
+ "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor,
131
+ "AveragePool": self._infer_Pool,
132
+ "BatchNormalization": self._infer_BatchNormalization,
133
+ "Cast": self._infer_Cast,
134
+ "CategoryMapper": self._infer_CategoryMapper,
135
+ "Compress": self._infer_Compress,
136
+ "Concat": self._infer_Concat,
137
+ "ConcatFromSequence": self._infer_ConcatFromSequence,
138
+ "Constant": self._infer_Constant,
139
+ "ConstantOfShape": self._infer_ConstantOfShape,
140
+ "Conv": self._infer_Conv,
141
+ "CumSum": self._pass_on_shape_and_type,
142
+ "Div": self._infer_symbolic_compute_ops,
143
+ "Einsum": self._infer_Einsum,
144
+ "Expand": self._infer_Expand,
145
+ "Equal": self._infer_symbolic_compute_ops,
146
+ "Floor": self._infer_symbolic_compute_ops,
147
+ "Gather": self._infer_Gather,
148
+ "GatherElements": self._infer_GatherElements,
149
+ "GatherND": self._infer_GatherND,
150
+ "Identity": self._pass_on_shape_and_type,
151
+ "If": self._infer_If,
152
+ "Loop": self._infer_Loop,
153
+ "MatMul": self._infer_MatMul,
154
+ "MatMulInteger16": self._infer_MatMulInteger,
155
+ "MaxPool": self._infer_Pool,
156
+ "Max": self._infer_symbolic_compute_ops,
157
+ "MemcpyFromHost": self._pass_on_shape_and_type,
158
+ "MemcpyToHost": self._pass_on_shape_and_type,
159
+ "Min": self._infer_symbolic_compute_ops,
160
+ "MoE": self._pass_on_shape_and_type,
161
+ "Mul": self._infer_symbolic_compute_ops,
162
+ "NonMaxSuppression": self._infer_NonMaxSuppression,
163
+ "NonZero": self._infer_NonZero,
164
+ "OneHot": self._infer_OneHot,
165
+ "Pad": self._infer_Pad,
166
+ "Range": self._infer_Range,
167
+ "Reciprocal": self._pass_on_shape_and_type,
168
+ "ReduceSum": self._infer_ReduceSum,
169
+ "ReduceMean": self._infer_ReduceMean,
170
+ "ReduceProd": self._infer_ReduceProd,
171
+ "Reshape": self._infer_Reshape,
172
+ "Resize": self._infer_Resize,
173
+ "Round": self._pass_on_shape_and_type,
174
+ "Scan": self._infer_Scan,
175
+ "ScatterElements": self._infer_ScatterElements,
176
+ "SequenceAt": self._infer_SequenceAt,
177
+ "SequenceInsert": self._infer_SequenceInsert,
178
+ "Shape": self._infer_Shape,
179
+ "Size": self._infer_Size,
180
+ "Slice": self._infer_Slice,
181
+ "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss,
182
+ "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss,
183
+ "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss,
184
+ "Split": self._infer_Split,
185
+ "SplitToSequence": self._infer_SplitToSequence,
186
+ "Squeeze": self._infer_Squeeze,
187
+ "Sub": self._infer_symbolic_compute_ops,
188
+ "Tile": self._infer_Tile,
189
+ "TopK": self._infer_TopK,
190
+ "Transpose": self._infer_Transpose,
191
+ "Unsqueeze": self._infer_Unsqueeze,
192
+ "Where": self._infer_symbolic_compute_ops,
193
+ "ZipMap": self._infer_ZipMap,
194
+ "Neg": self._infer_symbolic_compute_ops,
195
+ # contrib ops:
196
+ "Attention": self._infer_Attention,
197
+ "BiasAdd": self._infer_BiasAdd,
198
+ "BiasGelu": self._infer_BiasGelu,
199
+ "BiasSplitGelu": self._infer_BiasSplitGelu,
200
+ "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
201
+ "DequantizeLinear": self._infer_DequantizeLinear,
202
+ "DynamicTimeWarping": self._infer_DynamicTimeWarping,
203
+ "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
204
+ "FastGelu": self._infer_FastGelu,
205
+ "GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
206
+ "GatherBlockQuantized": self._infer_Gather,
207
+ "Gelu": self._infer_Gelu,
208
+ "GemmFastGelu": self._infer_GemmFastGelu,
209
+ "GemmFloat8": self._infer_GemmFloat8,
210
+ "GroupNorm": self._infer_GroupNorm,
211
+ "GroupNormalization": self._infer_GroupNorm,
212
+ "GroupQueryAttention": self._infer_GroupQueryAttention,
213
+ "LayerNormalization": self._infer_LayerNormalization,
214
+ "LongformerAttention": self._infer_LongformerAttention,
215
+ "MatMulNBits": self._infer_MatMulNBits,
216
+ "MultiHeadAttention": self._infer_MultiHeadAttention,
217
+ "NhwcConv": self._infer_NhwcConv,
218
+ "PackedAttention": self._infer_PackedAttention,
219
+ "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
220
+ "PagedAttention": self._infer_PagedAttention,
221
+ "PythonOp": self._infer_PythonOp,
222
+ "QLinearAdd": self._infer_QLinearBinary,
223
+ "QLinearMul": self._infer_QLinearBinary,
224
+ "QuantizeLinear": self._infer_QuantizeLinear,
225
+ "QuickGelu": self._infer_FastGelu,
226
+ "RelativePositionBias": self._infer_RelativePositionBias,
227
+ "RemovePadding": self._infer_RemovePadding,
228
+ "RestorePadding": self._infer_RestorePadding,
229
+ "RotaryEmbedding": self._infer_RotaryEmbedding,
230
+ "SimplifiedLayerNormalization": self._infer_LayerNormalization,
231
+ "SkipGroupNorm": self._infer_SkipGroupNorm,
232
+ "SkipLayerNormalization": self._infer_SkipLayerNormalization,
233
+ "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
234
+ "SparseAttention": self._infer_SparseAttention,
235
+ "UnfoldTensor": self._infer_UnfoldTensor,
236
+ }
237
+ self.aten_op_dispatcher_ = {
238
+ "embedding": self._infer_Gather,
239
+ "bitwise_or": self._infer_aten_bitwise_or,
240
+ "diagonal": self._infer_aten_diagonal,
241
+ "max_pool2d_with_indices": self._infer_aten_pool2d,
242
+ "max": self._infer_aten_minmax,
243
+ "min": self._infer_aten_minmax,
244
+ "multinomial": self._infer_aten_multinomial,
245
+ "unfold": self._infer_aten_unfold,
246
+ "argmax": self._infer_aten_argmax,
247
+ "avg_pool2d": self._infer_aten_pool2d,
248
+ "_adaptive_avg_pool2d": self._infer_aten_pool2d,
249
+ "numpy_T": self._infer_Transpose,
250
+ "native_group_norm": self._infer_aten_group_norm,
251
+ "upsample_nearest1d": self._infer_aten_upsample,
252
+ "upsample_nearest2d": self._infer_aten_upsample,
253
+ "upsample_nearest3d": self._infer_aten_upsample,
254
+ "upsample_bicubic2d": self._infer_aten_upsample,
255
+ }
256
+ self.run_ = True
257
+ self.suggested_merge_ = {}
258
+ self.symbolic_dims_ = {}
259
+ self.input_symbols_ = {}
260
+ self.auto_merge_ = auto_merge
261
+ self.guess_output_rank_ = guess_output_rank
262
+ self.verbose_ = verbose
263
+ self.int_max_ = int_max
264
+ self.subgraph_id_ = 0
265
+ self.prefix_ = prefix
266
+
267
+ def _add_suggested_merge(self, symbols, apply=False):
268
+ assert all((type(s) is str and s in self.symbolic_dims_) or is_literal(s) for s in symbols)
269
+ symbols = set(symbols)
270
+ for k, v in self.suggested_merge_.items():
271
+ if k in symbols:
272
+ symbols.remove(k)
273
+ symbols.add(v)
274
+ map_to = None
275
+ # if there is literal, map to it first
276
+ for s in symbols:
277
+ if is_literal(s):
278
+ map_to = s
279
+ break
280
+ # when no literals, map to input symbolic dims, then existing symbolic dims
281
+ if map_to is None:
282
+ for s in symbols:
283
+ if s in self.input_symbols_:
284
+ map_to = s
285
+ break
286
+ if map_to is None:
287
+ for s in symbols:
288
+ if type(self.symbolic_dims_[s]) is sympy.Symbol:
289
+ map_to = s
290
+ break
291
+ # when nothing to map to, use the shorter one
292
+ if map_to is None:
293
+ if self.verbose_ > 0:
294
+ logger.warning("Potential unsafe merge between symbolic expressions: (%s)", ",".join(symbols))
295
+ symbols_list = list(symbols)
296
+ lens = [len(s) for s in symbols_list]
297
+ map_to = symbols_list[lens.index(min(lens))]
298
+ symbols.remove(map_to)
299
+
300
+ for s in symbols:
301
+ if s == map_to:
302
+ continue
303
+ if is_literal(map_to) and is_literal(s):
304
+ assert int(map_to) == int(s)
305
+ self.suggested_merge_[s] = int(map_to) if is_literal(map_to) else map_to
306
+ for k, v in self.suggested_merge_.items():
307
+ if v == s:
308
+ self.suggested_merge_[k] = map_to
309
+ if apply and self.auto_merge_:
310
+ self._apply_suggested_merge()
311
+
312
+ def _apply_suggested_merge(self, graph_input_only=False):
313
+ if not self.suggested_merge_:
314
+ return
315
+ for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)):
316
+ for d in i.type.tensor_type.shape.dim:
317
+ if d.dim_param in self.suggested_merge_:
318
+ v = self.suggested_merge_[d.dim_param]
319
+ if is_literal(v):
320
+ d.dim_value = int(v)
321
+ else:
322
+ d.dim_param = v
323
+
324
+ def _preprocess(self, in_mp):
325
+ self.out_mp_ = onnx.ModelProto()
326
+ self.out_mp_.CopyFrom(in_mp)
327
+ self.graph_inputs_ = {i.name: i for i in list(self.out_mp_.graph.input)}
328
+ self.initializers_ = {i.name: i for i in self.out_mp_.graph.initializer}
329
+ self.known_vi_ = {i.name: i for i in list(self.out_mp_.graph.input)}
330
+ self.known_vi_.update(
331
+ {
332
+ i.name: helper.make_tensor_value_info(i.name, i.data_type, list(i.dims))
333
+ for i in self.out_mp_.graph.initializer
334
+ }
335
+ )
336
+
337
+ def _merge_symbols(self, dims):
338
+ if not all(type(d) is str for d in dims):
339
+ if self.auto_merge_:
340
+ unique_dims = list(set(dims))
341
+ is_int = [is_literal(d) for d in unique_dims]
342
+ assert sum(is_int) <= 1 # if there are more than 1 unique ints, something is wrong
343
+ if sum(is_int) == 1:
344
+ int_dim = is_int.index(1)
345
+ if self.verbose_ > 0:
346
+ logger.debug(
347
+ f"dim {unique_dims[:int_dim] + unique_dims[int_dim + 1 :]} has been merged with value {unique_dims[int_dim]}"
348
+ )
349
+ self._check_merged_dims(unique_dims, allow_broadcast=False)
350
+ return unique_dims[int_dim]
351
+ else:
352
+ if self.verbose_ > 0:
353
+ logger.debug(f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}")
354
+ return dims[0]
355
+ else:
356
+ return None
357
+ if all(d == dims[0] for d in dims):
358
+ return dims[0]
359
+ merged = [self.suggested_merge_.get(d, d) for d in dims]
360
+ if all(d == merged[0] for d in merged):
361
+ assert merged[0] in self.symbolic_dims_
362
+ return merged[0]
363
+ else:
364
+ return None
365
+
366
+ # broadcast from right to left, and merge symbolic dims if needed
367
+ def _broadcast_shapes(self, shape1, shape2):
368
+ new_shape = []
369
+ rank1 = len(shape1)
370
+ rank2 = len(shape2)
371
+ new_rank = max(rank1, rank2)
372
+ for i in range(new_rank):
373
+ dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
374
+ dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
375
+ if dim1 == 1 or dim1 == dim2:
376
+ new_dim = dim2
377
+ elif dim2 == 1:
378
+ new_dim = dim1
379
+ else:
380
+ new_dim = self._merge_symbols([dim1, dim2])
381
+ if not new_dim:
382
+ # warning about unsupported broadcast when not auto merge
383
+ # note that auto merge has the risk of incorrectly merge symbols while one of them being 1
384
+ # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
385
+ if self.auto_merge_:
386
+ self._add_suggested_merge([dim1, dim2], apply=True)
387
+ else:
388
+ logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2)) # noqa: G003
389
+ new_shape = [new_dim, *new_shape]
390
+ return new_shape
391
+
392
+ def _get_shape(self, node, idx):
393
+ name = node.input[idx]
394
+ if name in self.known_vi_:
395
+ vi = self.known_vi_[name]
396
+ return get_shape_from_value_info(vi)
397
+ else:
398
+ assert name in self.initializers_
399
+ return list(self.initializers_[name].dims)
400
+
401
+ def _try_get_shape(self, node, idx):
402
+ if idx > len(node.input) - 1:
403
+ return None
404
+ name = node.input[idx]
405
+ if name in self.known_vi_:
406
+ vi = self.known_vi_[name]
407
+ return get_shape_from_value_info(vi)
408
+ if name in self.initializers_:
409
+ return list(self.initializers_[name].dims)
410
+ return None
411
+
412
+ def _get_shape_rank(self, node, idx):
413
+ return len(self._get_shape(node, idx))
414
+
415
+ def _get_sympy_shape(self, node, idx):
416
+ sympy_shape = []
417
+ for d in self._get_shape(node, idx):
418
+ if type(d) is str:
419
+ sympy_shape.append(
420
+ self.symbolic_dims_[d]
421
+ if d in self.symbolic_dims_
422
+ else sympy.Symbol(d, integer=True, nonnegative=True)
423
+ )
424
+ else:
425
+ assert None is not d
426
+ sympy_shape.append(d)
427
+ return sympy_shape
428
+
429
+ def _get_value(self, node, idx):
430
+ name = node.input[idx]
431
+ assert name in self.sympy_data_ or name in self.initializers_
432
+ return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name])
433
+
434
+ def _try_get_value(self, node, idx):
435
+ if idx >= len(node.input):
436
+ return None
437
+ name = node.input[idx]
438
+ if name in self.sympy_data_ or name in self.initializers_:
439
+ return self._get_value(node, idx)
440
+ return None
441
+
442
+ def _update_computed_dims(self, new_sympy_shape):
443
+ for i, new_dim in enumerate(new_sympy_shape):
444
+ if not is_literal(new_dim) and type(new_dim) != str: # noqa: E721
445
+ str_dim = str(new_dim)
446
+ if str_dim in self.suggested_merge_:
447
+ if is_literal(self.suggested_merge_[str_dim]):
448
+ continue # no need to create dim for literals
449
+ new_sympy_shape[i] = self.symbolic_dims_[self.suggested_merge_[str_dim]]
450
+ else:
451
+ # add new_dim if it's a computational expression
452
+ if str(new_dim) not in self.symbolic_dims_:
453
+ self.symbolic_dims_[str(new_dim)] = new_dim
454
+
455
+ def _onnx_infer_single_node(self, node):
456
+ # skip onnx shape inference for some ops, as they are handled in _infer_*
457
+ skip_infer = node.op_type in [
458
+ "If",
459
+ "Loop",
460
+ "Scan",
461
+ "SplitToSequence",
462
+ "ZipMap", # contrib ops
463
+ "Attention",
464
+ "BiasAdd",
465
+ "BiasGelu",
466
+ "BiasSplitGelu",
467
+ "DequantizeLinear",
468
+ "DynamicTimeWarping",
469
+ "EmbedLayerNormalization",
470
+ "FastGelu",
471
+ "GatherBlockQuantized",
472
+ "Gelu",
473
+ "GemmFastGelu",
474
+ "GroupNorm",
475
+ "GroupNormalization",
476
+ "GroupQueryAttention",
477
+ "LayerNormalization",
478
+ "LongformerAttention",
479
+ "MultiHeadAttention",
480
+ "NhwcConv",
481
+ "PackedAttention",
482
+ "PagedAttention",
483
+ "PythonOp",
484
+ "QuantizeLinear",
485
+ "QuickGelu",
486
+ "RelativePositionBias",
487
+ "RemovePadding",
488
+ "RestorePadding",
489
+ "RotaryEmbedding",
490
+ "SimplifiedLayerNormalization",
491
+ "SkipLayerNormalization",
492
+ "SkipSimplifiedLayerNormalization",
493
+ "SparseAttention",
494
+ "SkipGroupNorm",
495
+ "QLinearAdd",
496
+ "QLinearMul",
497
+ ]
498
+
499
+ if not skip_infer:
500
+ # Only pass initializers that satisfy the following condition:
501
+ # (1) Operator need value of some input for shape inference.
502
+ # For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output.
503
+ # (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec.
504
+ # (3) The initializer is not in graph input. The means the node input is "constant" in inference.
505
+ initializers = []
506
+ if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]:
507
+ initializers = [
508
+ self.initializers_[name]
509
+ for name in node.input
510
+ if (name in self.initializers_ and name not in self.graph_inputs_)
511
+ ]
512
+
513
+ if node.op_type in [
514
+ "Add",
515
+ "Sub",
516
+ "Mul",
517
+ "Div",
518
+ "MatMul",
519
+ "MatMulInteger",
520
+ "MatMulInteger16",
521
+ "Where",
522
+ "Sum",
523
+ ]:
524
+ if node.output[0] in self.known_vi_:
525
+ vi = self.known_vi_[node.output[0]]
526
+ out_rank = len(get_shape_from_type_proto(vi.type))
527
+ in_shapes = [self._get_shape(node, i) for i in range(len(node.input))]
528
+ for d in range(
529
+ out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)
530
+ ):
531
+ in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
532
+ if len(in_dims) > 1:
533
+ self._check_merged_dims(in_dims, allow_broadcast=True)
534
+
535
+ # run single node inference with self.known_vi_ shapes
536
+ tmp_graph = helper.make_graph(
537
+ [node],
538
+ "tmp",
539
+ [self.known_vi_[i] for i in node.input if i],
540
+ [make_named_value_info(i) for i in node.output],
541
+ initializers,
542
+ )
543
+
544
+ self.tmp_mp_.graph.CopyFrom(tmp_graph)
545
+
546
+ self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
547
+
548
+ for i_o in range(len(node.output)):
549
+ o = node.output[i_o]
550
+ if o: # skip optional output
551
+ vi = self.out_mp_.graph.value_info.add()
552
+ if not skip_infer:
553
+ vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
554
+ else:
555
+ vi.name = o
556
+ self.known_vi_[o] = vi
557
+
558
+ def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True):
559
+ if self.verbose_ > 2:
560
+ logger.debug(f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}")
561
+ # node inputs are not passed directly to the subgraph
562
+ # it's up to the node dispatcher to prepare subgraph input
563
+ # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape
564
+ # besides, inputs in subgraph could shadow implicit inputs
565
+ subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)}
566
+ subgraph_implicit_input = {name for name in self.known_vi_ if name not in subgraph_inputs}
567
+ tmp_graph = helper.make_graph(
568
+ list(subgraph.node),
569
+ "tmp",
570
+ list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input],
571
+ [make_named_value_info(i.name) for i in subgraph.output],
572
+ )
573
+ tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input])
574
+ tmp_graph.initializer.extend(subgraph.initializer)
575
+ self.tmp_mp_.graph.CopyFrom(tmp_graph)
576
+
577
+ symbolic_shape_inference = SymbolicShapeInference(
578
+ self.int_max_,
579
+ self.auto_merge_,
580
+ self.guess_output_rank_,
581
+ self.verbose_,
582
+ prefix=self.prefix_ + "_" + str(self.subgraph_id_),
583
+ )
584
+ if inc_subgraph_id:
585
+ self.subgraph_id_ += 1
586
+
587
+ symbolic_shape_inference._preprocess(self.tmp_mp_)
588
+ symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
589
+ while symbolic_shape_inference.run_:
590
+ symbolic_shape_inference._infer_impl(self.sympy_data_.copy())
591
+ symbolic_shape_inference._update_output_from_vi()
592
+ if use_node_input:
593
+ # if subgraph uses node input, it needs to update to merged dims
594
+ subgraph.ClearField("input")
595
+ subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)])
596
+ subgraph.ClearField("output")
597
+ subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output)
598
+ subgraph.ClearField("value_info")
599
+ subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info)
600
+ subgraph.ClearField("node")
601
+ subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node)
602
+ # for new symbolic dims from subgraph output, add to main graph symbolic dims
603
+ subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output]
604
+ subgraph_new_symbolic_dims = {
605
+ d for s in subgraph_shapes if s for d in s if type(d) is str and d not in self.symbolic_dims_
606
+ }
607
+ new_dims = {}
608
+ for d in subgraph_new_symbolic_dims:
609
+ assert d in symbolic_shape_inference.symbolic_dims_
610
+ new_dims[d] = symbolic_shape_inference.symbolic_dims_[d]
611
+ self.symbolic_dims_.update(new_dims)
612
+ return symbolic_shape_inference
613
+
614
+ def _get_int_or_float_values(self, node, broadcast=False, allow_float_values=False):
615
+ def int_or_float(value, allow_float_values):
616
+ # If casting into int has precision loss: keep float output
617
+ if allow_float_values and value % 1 != 0:
618
+ return value
619
+ return int(value)
620
+
621
+ values = [self._try_get_value(node, i) for i in range(len(node.input))]
622
+ if all(v is not None for v in values):
623
+ # some shape compute is in floating point, cast to int for sympy
624
+ for i, v in enumerate(values):
625
+ if type(v) is not np.ndarray:
626
+ continue
627
+ if len(v.shape) > 1:
628
+ new_v = None # ignore value for rank > 1
629
+ elif len(v.shape) == 0:
630
+ new_v = int_or_float(v.item(), allow_float_values)
631
+ else:
632
+ assert len(v.shape) == 1
633
+ new_v = [int_or_float(vv, allow_float_values) for vv in v]
634
+ values[i] = new_v
635
+ values_len = [len(v) if isinstance(v, list) else 0 for v in values]
636
+ max_len = max(values_len)
637
+ if max_len >= 1 and broadcast:
638
+ # broadcast
639
+ for i, v in enumerate(values):
640
+ if v is None:
641
+ continue # don't broadcast if value is unknown
642
+ if isinstance(v, list):
643
+ if len(v) < max_len:
644
+ values[i] = v * max_len
645
+ else:
646
+ assert len(v) == max_len
647
+ else:
648
+ values[i] = [v] * max_len
649
+ return values
650
+
651
+ def _compute_on_sympy_data(self, node, op_func):
652
+ assert len(node.output) == 1
653
+
654
+ # Before mul & div operations
655
+ # cast inputs into interger might lose decimal part and reduce precision
656
+ # keep them as float, finish the operation, then cast the result into integer
657
+ if node.op_type in ["Mul", "Div"]:
658
+ values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True)
659
+ else:
660
+ values = self._get_int_or_float_values(node, broadcast=True)
661
+
662
+ if all(v is not None for v in values):
663
+ is_list = [isinstance(v, list) for v in values]
664
+ as_list = any(is_list)
665
+ if as_list:
666
+ self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values, strict=False)]
667
+ else:
668
+ self.sympy_data_[node.output[0]] = op_func(values)
669
+
670
+ def _pass_on_sympy_data(self, node):
671
+ assert len(node.input) == 1 or node.op_type in [
672
+ "Reshape",
673
+ "Unsqueeze",
674
+ "Squeeze",
675
+ ]
676
+ self._compute_on_sympy_data(node, lambda x: x[0])
677
+
678
+ def _pass_on_shape_and_type(self, node):
679
+ vi = self.known_vi_[node.output[0]]
680
+ vi.CopyFrom(
681
+ helper.make_tensor_value_info(
682
+ node.output[0],
683
+ get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type),
684
+ self._get_shape(node, 0),
685
+ )
686
+ )
687
+
688
+ def _new_symbolic_dim(self, prefix, dim):
689
+ new_dim = f"{prefix}_d{dim}"
690
+ if new_dim in self.suggested_merge_:
691
+ v = self.suggested_merge_[new_dim]
692
+ new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v
693
+ else:
694
+ new_symbolic_dim = sympy.Symbol(new_dim, integer=True, nonnegative=True)
695
+ self.symbolic_dims_[new_dim] = new_symbolic_dim
696
+ return new_symbolic_dim
697
+
698
+ def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
699
+ return self._new_symbolic_dim(
700
+ f"{node.op_type}{self.prefix_}_{list(self.out_mp_.graph.node).index(node)}_o{out_idx}_",
701
+ dim,
702
+ )
703
+
704
+ def _new_symbolic_shape(self, rank, node, out_idx=0):
705
+ return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)]
706
+
707
+ def _compute_conv_pool_shape(self, node, channels_last=False):
708
+ sympy_shape = self._get_sympy_shape(node, 0)
709
+ if len(node.input) > 1:
710
+ W_shape = self._get_sympy_shape(node, 1) # noqa: N806
711
+ rank = len(W_shape) - 2 # number of spatial axes
712
+ kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
713
+ sympy_shape[3 if channels_last else 1] = W_shape[0]
714
+ else:
715
+ W_shape = None # noqa: N806
716
+ kernel_shape = get_attribute(node, "kernel_shape")
717
+ rank = len(kernel_shape)
718
+
719
+ assert len(sympy_shape) == rank + 2
720
+
721
+ # only need to symbolic shape inference if input has symbolic dims in spatial axes
722
+ spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:]
723
+ is_symbolic_dims = [not is_literal(i) for i in spatial_shape]
724
+
725
+ if not any(is_symbolic_dims):
726
+ shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
727
+ if len(shape) > 0:
728
+ assert len(sympy_shape) == len(shape)
729
+ if channels_last:
730
+ sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]]
731
+ else:
732
+ sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
733
+ return sympy_shape
734
+
735
+ dilations = get_attribute(node, "dilations", [1] * rank)
736
+ strides = get_attribute(node, "strides", [1] * rank)
737
+ effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations, strict=False)]
738
+ pads = get_attribute(node, "pads")
739
+ if pads is None:
740
+ pads = [0] * (2 * rank)
741
+ auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8")
742
+ if auto_pad != "VALID" and auto_pad != "NOTSET":
743
+ try:
744
+ residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides, strict=False)]
745
+ total_pads = [
746
+ max(0, (k - s) if r == 0 else (k - r))
747
+ for k, s, r in zip(effective_kernel_shape, strides, residual, strict=False)
748
+ ]
749
+ except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational
750
+ total_pads = [
751
+ max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides, strict=False)
752
+ ] # assuming no residual if sympy throws error
753
+ elif auto_pad == "VALID":
754
+ total_pads = []
755
+ else:
756
+ total_pads = [0] * rank
757
+ else:
758
+ assert len(pads) == 2 * rank
759
+ total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:], strict=False)]
760
+
761
+ ceil_mode = get_attribute(node, "ceil_mode", 0)
762
+ for i in range(rank):
763
+ effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)]
764
+ if len(total_pads) > 0:
765
+ effective_input_size = effective_input_size + total_pads[i]
766
+ if ceil_mode:
767
+ strided_kernel_positions = sympy.ceiling(
768
+ (effective_input_size - effective_kernel_shape[i]) / strides[i]
769
+ )
770
+ else:
771
+ strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i]
772
+ sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1
773
+ return sympy_shape
774
+
775
+ def _check_merged_dims(self, dims, allow_broadcast=True):
776
+ if allow_broadcast:
777
+ dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
778
+ if not all(d == dims[0] for d in dims):
779
+ self._add_suggested_merge(dims, apply=True)
780
+
781
+ def _compute_matmul_shape(self, node, output_dtype=None):
782
+ lhs_shape = self._get_shape(node, 0)
783
+ rhs_shape = self._get_shape(node, 1)
784
+ lhs_rank = len(lhs_shape)
785
+ rhs_rank = len(rhs_shape)
786
+ lhs_reduce_dim = 0
787
+ rhs_reduce_dim = 0
788
+ assert lhs_rank > 0 and rhs_rank > 0
789
+ if lhs_rank == 1 and rhs_rank == 1:
790
+ new_shape = []
791
+ elif lhs_rank == 1:
792
+ rhs_reduce_dim = -2
793
+ new_shape = [*rhs_shape[:rhs_reduce_dim], rhs_shape[-1]]
794
+ elif rhs_rank == 1:
795
+ lhs_reduce_dim = -1
796
+ new_shape = lhs_shape[:lhs_reduce_dim]
797
+ else:
798
+ lhs_reduce_dim = -1
799
+ rhs_reduce_dim = -2
800
+ new_shape = [*self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]), lhs_shape[-2], rhs_shape[-1]]
801
+ # merge reduce dim
802
+ self._check_merged_dims(
803
+ [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
804
+ allow_broadcast=False,
805
+ )
806
+ if output_dtype is None:
807
+ # infer output_dtype from input type when not specified
808
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
809
+ vi = self.known_vi_[node.output[0]]
810
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
811
+
812
+ def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
813
+ """
814
+ update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches
815
+ """
816
+ dst_tensor_type = (
817
+ dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type
818
+ )
819
+ src_tensor_type = (
820
+ src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type
821
+ )
822
+ if dst_tensor_type.elem_type != src_tensor_type.elem_type:
823
+ node_id = node.name if node.name else node.op_type
824
+ raise ValueError(
825
+ f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
826
+ f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
827
+ f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
828
+ )
829
+ if dst_tensor_type.HasField("shape"):
830
+ for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim, strict=False)):
831
+ if ds[0] != ds[1]:
832
+ # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type
833
+ # for sequence_type, clear the dimension
834
+ new_dim = onnx.TensorShapeProto.Dimension()
835
+ if not is_sequence(dst_type):
836
+ new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di))
837
+ dst_tensor_type.shape.dim[di].CopyFrom(new_dim)
838
+ else:
839
+ dst_tensor_type.CopyFrom(src_tensor_type)
840
+
841
+ def _infer_ArrayFeatureExtractor(self, node): # noqa: N802
842
+ data_shape = self._get_shape(node, 0)
843
+ indices_shape = self._get_shape(node, 1)
844
+ vi = self.known_vi_[node.output[0]]
845
+ vi.CopyFrom(
846
+ helper.make_tensor_value_info(
847
+ node.output[0],
848
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
849
+ data_shape[:-1] + indices_shape,
850
+ )
851
+ )
852
+
853
+ def _infer_symbolic_compute_ops(self, node):
854
+ funcs = {
855
+ "Add": lambda l: l[0] + l[1], # noqa: E741
856
+ "Div": lambda l: ( # noqa: E741
857
+ int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1]
858
+ ), # integer div in sympy
859
+ "Equal": lambda l: l[0] == l[1], # noqa: E741
860
+ "Floor": lambda l: sympy.floor(l[0]), # noqa: E741
861
+ "Max": lambda l: ( # noqa: E741
862
+ l[1]
863
+ if is_literal(l[0]) and int(l[0]) < -self.int_max_
864
+ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1]))
865
+ ),
866
+ "Min": lambda l: ( # noqa: E741
867
+ l[1]
868
+ if is_literal(l[0]) and int(l[0]) > self.int_max_
869
+ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1]))
870
+ ),
871
+ "Mul": lambda l: int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1], # noqa: E741
872
+ "Sub": lambda l: l[0] - l[1], # noqa: E741
873
+ "Where": lambda l: l[1] if l[0] else l[2], # noqa: E741
874
+ "Neg": lambda l: -l[0], # noqa: E741
875
+ }
876
+ assert node.op_type in funcs
877
+ self._compute_on_sympy_data(node, funcs[node.op_type])
878
+
879
+ def _infer_Cast(self, node): # noqa: N802
880
+ self._pass_on_sympy_data(node)
881
+
882
+ def _infer_CategoryMapper(self, node): # noqa: N802
883
+ input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
884
+ if input_type == onnx.TensorProto.STRING:
885
+ output_type = onnx.TensorProto.INT64
886
+ else:
887
+ output_type = onnx.TensorProto.STRING
888
+ vi = self.known_vi_[node.output[0]]
889
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_type, self._get_shape(node, 0)))
890
+
891
+ def _infer_Compress(self, node): # noqa: N802
892
+ input_shape = self._get_shape(node, 0)
893
+ # create a new symbolic dimension for Compress output
894
+ compress_len = str(self._new_symbolic_dim_from_output(node))
895
+ axis = get_attribute(node, "axis")
896
+ if axis is None:
897
+ # when axis is not specified, input is flattened before compress so output is 1D
898
+ output_shape = [compress_len]
899
+ else:
900
+ output_shape = input_shape
901
+ output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len
902
+ vi = self.known_vi_[node.output[0]]
903
+ vi.CopyFrom(
904
+ helper.make_tensor_value_info(
905
+ node.output[0],
906
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
907
+ output_shape,
908
+ )
909
+ )
910
+
911
+ def _infer_Concat(self, node): # noqa: N802
912
+ if any(i in self.sympy_data_ or i in self.initializers_ for i in node.input):
913
+ values = self._get_int_or_float_values(node)
914
+ if all(v is not None for v in values):
915
+ assert get_attribute(node, "axis") == 0
916
+ self.sympy_data_[node.output[0]] = []
917
+ for i in range(len(node.input)):
918
+ value = values[i]
919
+ if isinstance(value, list):
920
+ self.sympy_data_[node.output[0]].extend(value)
921
+ else:
922
+ self.sympy_data_[node.output[0]].append(value)
923
+
924
+ sympy_shape = self._get_sympy_shape(node, 0)
925
+ axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape))
926
+ for i_idx in range(1, len(node.input)):
927
+ input_shape = self._get_sympy_shape(node, i_idx)
928
+ if input_shape:
929
+ sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
930
+ self._update_computed_dims(sympy_shape)
931
+ # merge symbolic dims for non-concat axes
932
+ for d in range(len(sympy_shape)):
933
+ if d == axis:
934
+ continue
935
+ dims = [self._get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if self._get_shape(node, i_idx)]
936
+ if all(d == dims[0] for d in dims):
937
+ continue
938
+ merged = self._merge_symbols(dims)
939
+ if type(merged) is str:
940
+ sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
941
+ else:
942
+ sympy_shape[d] = merged
943
+ vi = self.known_vi_[node.output[0]]
944
+ vi.CopyFrom(
945
+ helper.make_tensor_value_info(
946
+ node.output[0],
947
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
948
+ get_shape_from_sympy_shape(sympy_shape),
949
+ )
950
+ )
951
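+ # Worked example: concatenating shapes (a, 4) and (b, 4) on axis 0 yields
+ # (a + b, 4); for non-concat axes whose dims disagree, the dims are merged
+ # via _merge_symbols rather than summed.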
+
952
+ def _infer_ConcatFromSequence(self, node): # noqa: N802
953
+ seq_shape = self._get_shape(node, 0)
954
+ new_axis = 1 if get_attribute(node, "new_axis") else 0
955
+ axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis)
956
+ concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis))
957
+ new_shape = seq_shape
958
+ if new_axis:
959
+ new_shape = [*seq_shape[:axis], concat_dim, *seq_shape[axis:]]
960
+ else:
961
+ new_shape[axis] = concat_dim
962
+ vi = self.known_vi_[node.output[0]]
963
+ vi.CopyFrom(
964
+ helper.make_tensor_value_info(
965
+ node.output[0],
966
+ self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type,
967
+ new_shape,
968
+ )
969
+ )
970
+
971
+ def _infer_Constant(self, node): # noqa: N802
972
+ t = get_attribute(node, "value")
973
+ self.sympy_data_[node.output[0]] = numpy_helper.to_array(t)
974
+
975
+ def _infer_ConstantOfShape(self, node): # noqa: N802
976
+ sympy_shape = self._get_int_or_float_values(node)[0]
977
+ vi = self.known_vi_[node.output[0]]
978
+ if sympy_shape is not None:
979
+ if type(sympy_shape) != list: # noqa: E721
980
+ sympy_shape = [sympy_shape]
981
+ self._update_computed_dims(sympy_shape)
982
+ # update sympy data if output type is int, and shape is known
983
+ if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all(is_literal(x) for x in sympy_shape):
984
+ self.sympy_data_[node.output[0]] = np.ones(
985
+ [int(x) for x in sympy_shape], dtype=np.int64
986
+ ) * numpy_helper.to_array(get_attribute(node, "value", 0))
987
+ else:
988
+ # create new dynamic shape
989
+ # note input 0 is a 1-D shape tensor, so the new symbolic shape's rank equals that vector's length
990
+ sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node)
991
+
992
+ vi.CopyFrom(
993
+ helper.make_tensor_value_info(
994
+ node.output[0],
995
+ vi.type.tensor_type.elem_type,
996
+ get_shape_from_sympy_shape(sympy_shape),
997
+ )
998
+ )
999
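+ # Worked example: a literal shape input [2, 3] with an INT64 'value' of 7
+ # records the 2x3 array of 7s in sympy_data_, so downstream shape
+ # arithmetic (e.g. a later Reshape) can consume concrete values.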
+
1000
+ def _infer_Conv(self, node): # noqa: N802
1001
+ sympy_shape = self._compute_conv_pool_shape(node)
1002
+ self._update_computed_dims(sympy_shape)
1003
+ vi = self.known_vi_[node.output[0]]
1004
+ vi.CopyFrom(
1005
+ helper.make_tensor_value_info(
1006
+ node.output[0],
1007
+ vi.type.tensor_type.elem_type,
1008
+ get_shape_from_sympy_shape(sympy_shape),
1009
+ )
1010
+ )
1011
+
1012
+ def _infer_NhwcConv(self, node): # noqa: N802
1013
+ sympy_shape = self._compute_conv_pool_shape(node, channels_last=True)
1014
+ self._update_computed_dims(sympy_shape)
1015
+ vi = self.known_vi_[node.output[0]]
1016
+ vi.CopyFrom(
1017
+ helper.make_tensor_value_info(
1018
+ node.output[0],
1019
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1020
+ get_shape_from_sympy_shape(sympy_shape),
1021
+ )
1022
+ )
1023
+
1024
+ def _infer_DequantizeLinear(self, node): # noqa: N802
1025
+ # Get the output data type from the scale input (index 1, required).
1026
+ output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type
1027
+
1028
+ # Get the output shape from the first input.
1029
+ output_shape = self._get_shape(node, 0)
1030
+
1031
+ vi = self.known_vi_[node.output[0]]
1032
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
1033
+
1034
+ def _infer_QuantizeLinear(self, node): # noqa: N802
1035
+ # Get the output data type from the zero-point input (index 2, optional).
1036
+ # Otherwise, default to uint8
1037
+ output_dtype = onnx.TensorProto.UINT8
1038
+ if len(node.input) > 2 and node.input[2]:
1039
+ output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type
1040
+
1041
+ # Get the output shape from the first input.
1042
+ output_shape = self._get_shape(node, 0)
1043
+
1044
+ vi = self.known_vi_[node.output[0]]
1045
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
1046
+
1047
+ def _infer_QLinearBinary(self, node): # noqa: N802
1048
+ # Get the output data type from the first input to QLinearAdd / QLinearMul.
1049
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
1050
+
1051
+ # The two data inputs (A and B) are the first and fourth operands; their scales and zero-points sit between them.
1052
+ input_1_shape = self._get_shape(node, 0)
1053
+ input_2_shape = self._get_shape(node, 3)
1054
+
1055
+ # Compute the broadcasted shape
1056
+ new_shape = self._broadcast_shapes(input_1_shape, input_2_shape)
1057
+
1058
+ vi = self.known_vi_[node.output[0]]
1059
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
1060
+
1061
+ def _infer_Einsum(self, node): # noqa: N802
1062
+ # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
1063
+ equation = get_attribute(node, "equation")
1064
+ equation = equation.replace(b" ", b"")
1065
+ mid_index = equation.find(b"->")
1066
+ left_equation = equation[:mid_index] if mid_index != -1 else equation
1067
+
1068
+ num_operands = 0
1069
+ num_ellipsis = 0
1070
+ num_ellipsis_indices = 0
1071
+
1072
+ letter_to_dim = {}
1073
+
1074
+ terms = left_equation.split(b",")
1075
+ for term in terms:
1076
+ ellipsis_index = term.find(b"...")
1077
+ shape = self._get_shape(node, num_operands)
1078
+ rank = len(shape)
1079
+ if ellipsis_index != -1:
1080
+ if num_ellipsis == 0:
1081
+ num_ellipsis_indices = rank - len(term) + 3
1082
+ num_ellipsis = num_ellipsis + 1
1083
+ for i in range(1, rank + 1):
1084
+ letter = term[-i]
1085
+ if letter != 46: # letter != b'.'
1086
+ dim = shape[-i]
1087
+ if letter not in letter_to_dim:
1088
+ letter_to_dim[letter] = dim
1089
+ elif type(dim) is not sympy.Symbol:
1090
+ letter_to_dim[letter] = dim
1091
+ num_operands = num_operands + 1
1092
+
1093
+ new_sympy_shape = []
1094
+ from collections import OrderedDict # noqa: PLC0415
1095
+
1096
+ num_letter_occurrences = OrderedDict()
1097
+ if mid_index != -1:
1098
+ right_equation = equation[mid_index + 2 :]
1099
+ right_ellipsis_index = right_equation.find(b"...")
1100
+ if right_ellipsis_index != -1:
1101
+ for i in range(num_ellipsis_indices):
1102
+ new_sympy_shape.append(shape[i])
1103
+ for c in right_equation:
1104
+ if c != 46: # c != b'.'
1105
+ new_sympy_shape.append(letter_to_dim[c])
1106
+ else:
1107
+ for i in range(num_ellipsis_indices):
1108
+ new_sympy_shape.append(shape[i])
1109
+ for c in left_equation:
1110
+ if c != 44 and c != 46: # c != b',' and c != b'.':
1111
+ if c in num_letter_occurrences:
1112
+ num_letter_occurrences[c] = num_letter_occurrences[c] + 1
1113
+ else:
1114
+ num_letter_occurrences[c] = 1
1115
+ for key, value in num_letter_occurrences.items():
1116
+ if value == 1:
1117
+ new_sympy_shape.append(letter_to_dim[key])
1118
+
1119
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
1120
+ vi = self.known_vi_[node.output[0]]
1121
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape))
1122
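+ # Worked example: equation b"bij,bjk->bik" with input shapes (B, 2, 3) and
+ # (B, 3, 4) maps i->2, j->3, k->4, so the inferred output shape is
+ # [B, 2, 4]. Without "->", letters occurring exactly once on the left are
+ # kept, in order of first appearance (as implemented above).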
+
1123
+ def _infer_Expand(self, node): # noqa: N802
1124
+ expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True)
1125
+ if expand_to_shape is not None:
1126
+ # new_shape's dim can come from shape value
1127
+ self._update_computed_dims(expand_to_shape)
1128
+ shape = self._get_shape(node, 0)
1129
+ new_shape = self._broadcast_shapes(shape, get_shape_from_sympy_shape(expand_to_shape))
1130
+ vi = self.known_vi_[node.output[0]]
1131
+ vi.CopyFrom(
1132
+ helper.make_tensor_value_info(
1133
+ node.output[0],
1134
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1135
+ new_shape,
1136
+ )
1137
+ )
1138
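+ # Worked example: an input of shape (3, 1) expanded with a shape value of
+ # (2, 1, 6) broadcasts to (2, 3, 6); if the shape value cannot be fetched
+ # statically, the output value info is left untouched here.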
+
1139
+ def _infer_Gather(self, node): # noqa: N802
1140
+ data_shape = self._get_shape(node, 0)
1141
+ axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape))
1142
+ indices_shape = self._get_shape(node, 1)
1143
+ vi = self.known_vi_[node.output[0]]
1144
+ if node.op_type == "Gather":
1145
+ elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
1146
+ elif node.op_type == "GatherBlockQuantized":
1147
+ # scales
1148
+ elem_type = self.known_vi_[node.input[2]].type.tensor_type.elem_type
1149
+ else:
1150
+ raise ValueError(f"Unsupported Gather op_type: {node.op_type}")
1151
+ vi.CopyFrom(
1152
+ helper.make_tensor_value_info(
1153
+ node.output[0],
1154
+ elem_type,
1155
+ data_shape[:axis] + indices_shape + data_shape[axis + 1 :],
1156
+ )
1157
+ )
1158
+ # for 1D input, do some sympy compute
1159
+ if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and get_attribute(node, "axis", 0) == 0:
1160
+ idx = self._try_get_value(node, 1)
1161
+ if idx is not None:
1162
+ data = self.sympy_data_[node.input[0]]
1163
+ if type(data) is list:
1164
+ if type(idx) is np.ndarray and len(idx.shape) == 1:
1165
+ self.sympy_data_[node.output[0]] = [data[int(i)] for i in idx]
1166
+ else:
1167
+ self.sympy_data_[node.output[0]] = data[int(idx)]
1168
+ else:
1169
+ assert idx == 0 or idx == -1
1170
+ self.sympy_data_[node.output[0]] = data
1171
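+ # Worked example: data of shape (5, 4, 3) gathered on axis=1 with indices
+ # of shape (2, 2) yields (5, 2, 2, 3), i.e.
+ # data_shape[:axis] + indices_shape + data_shape[axis + 1:].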
+
1172
+ def _infer_GatherElements(self, node): # noqa: N802
1173
+ indices_shape = self._get_shape(node, 1)
1174
+ vi = self.known_vi_[node.output[0]]
1175
+ vi.CopyFrom(
1176
+ helper.make_tensor_value_info(
1177
+ node.output[0],
1178
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1179
+ indices_shape,
1180
+ )
1181
+ )
1182
+
1183
+ def _infer_GatherND(self, node): # noqa: N802
1184
+ data_shape = self._get_shape(node, 0)
1185
+ data_rank = len(data_shape)
1186
+ indices_shape = self._get_shape(node, 1)
1187
+ last_index_dimension = indices_shape[-1]
1189
+ assert is_literal(last_index_dimension) and last_index_dimension <= data_rank
1190
+ new_shape = indices_shape[:-1] + data_shape[last_index_dimension:]
1191
+ vi = self.known_vi_[node.output[0]]
1192
+ vi.CopyFrom(
1193
+ helper.make_tensor_value_info(
1194
+ node.output[0],
1195
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1196
+ new_shape,
1197
+ )
1198
+ )
1199
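+ # Worked example: data of shape (2, 3, 4) with indices of shape (5, 2) has
+ # last_index_dimension == 2, so the output shape is
+ # indices_shape[:-1] + data_shape[2:] == (5, 4).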
+
1200
+ def _infer_If(self, node): # noqa: N802
1201
+ # special case for constant condition, in case there are mismatching shape from the non-executed branch
1202
+ subgraphs = [
1203
+ get_attribute(node, "then_branch"),
1204
+ get_attribute(node, "else_branch"),
1205
+ ]
1206
+ cond = self._try_get_value(node, 0)
1207
+ if cond is not None:
1208
+ if as_scalar(cond) > 0:
1209
+ subgraphs[1].CopyFrom(subgraphs[0])
1210
+ else:
1211
+ subgraphs[0].CopyFrom(subgraphs[1])
1212
+
1213
+ for i_sub, subgraph in enumerate(subgraphs):
1214
+ subgraph_infer = self._onnx_infer_subgraph(node, subgraph, use_node_input=False)
1215
+ for i_out in range(len(node.output)):
1216
+ vi = self.known_vi_[node.output[i_out]]
1217
+ if i_sub == 0:
1218
+ vi.CopyFrom(subgraph.output[i_out])
1219
+ vi.name = node.output[i_out]
1220
+ else:
1221
+ self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type)
1222
+
1223
+ # pass on sympy data from subgraph, if cond is constant
1224
+ if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else 1):
1225
+ if subgraph.output[i_out].name in subgraph_infer.sympy_data_:
1226
+ self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name]
1227
+
1228
+ def _infer_Loop(self, node): # noqa: N802
1229
+ subgraph = get_attribute(node, "body")
1230
+ assert len(subgraph.input) == len(node.input)
1231
+ num_loop_carried = len(node.input) - 2 # minus the length and initial loop condition
1232
+ # when sequence_type is used as a loop-carried input,
1233
+ # subgraph inference needs to run twice if a tensor shape in the sequence contains None
1234
+ for i, si in enumerate(subgraph.input):
1235
+ si_name = si.name
1236
+ si.CopyFrom(self.known_vi_[node.input[i]])
1237
+ si.name = si_name
1238
+
1239
+ self._onnx_infer_subgraph(node, subgraph)
1240
+
1241
+ # check subgraph input/output for shape changes in loop carried variables
1242
+ # for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a)
1243
+ # for sequence_type, propagate from output to input
1244
+ need_second_infer = False
1245
+ for i_out in range(1, num_loop_carried + 1):
1246
+ so = subgraph.output[i_out]
1247
+ so_shape = get_shape_from_value_info(so)
1248
+ if is_sequence(so.type):
1249
+ if so_shape and None in so_shape:
1250
+ # copy shape from output to input
1251
+ # note that loop input is [loop_len, cond, input_0, input_1, ...]
1252
+ # while loop output is [cond, output_0, output_1, ...]
1253
+ subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type)
1254
+ need_second_infer = True
1255
+ else:
1256
+ si = subgraph.input[i_out + 1]
1257
+ si_shape = get_shape_from_value_info(si)
1258
+ for di, dims in enumerate(zip(si_shape, so_shape, strict=False)):
1259
+ if dims[0] != dims[1]:
1260
+ new_dim = onnx.TensorShapeProto.Dimension()
1261
+ new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, i_out, di))
1262
+ si.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
1263
+ so.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
1264
+ need_second_infer = True
1265
+
1266
+ if need_second_infer:
1267
+ if self.verbose_ > 2:
1268
+ logger.debug(
1269
+ f"Rerun Loop: {node.name}({node.output[0]}...), because of sequence in loop carried variables"
1270
+ )
1271
+ self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False)
1272
+
1273
+ # create a new symbolic dimension for iteration dependent dimension
1274
+ loop_iter_dim = str(self._new_symbolic_dim_from_output(node))
1275
+ for i in range(len(node.output)):
1276
+ vi = self.known_vi_[node.output[i]]
1277
+ vi.CopyFrom(subgraph.output[i + 1]) # first subgraph output is condition, not in node output
1278
+ if i >= num_loop_carried:
1279
+ assert not is_sequence(vi.type) # TODO: handle loop accumulation in sequence_type
1280
+ subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim
1281
+ vi.type.tensor_type.shape.ClearField("dim")
1282
+ vi_dim = vi.type.tensor_type.shape.dim
1283
+ vi_dim.add().dim_param = loop_iter_dim
1284
+ vi_dim.extend(list(subgraph_vi_dim))
1285
+ vi.name = node.output[i]
1286
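+ # Worked example: Loop inputs [max_trip_count, cond, x] with body outputs
+ # [cond_out, x_out, scan_out] produce node outputs [x_final, scan_accum],
+ # where scan_accum gains the fresh leading iteration dim loop_iter_dim.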
+
1287
+ def _infer_MatMul(self, node): # noqa: N802
1288
+ self._compute_matmul_shape(node)
1289
+
1290
+ def _infer_MatMulInteger(self, node): # noqa: N802
1291
+ self._compute_matmul_shape(node, onnx.TensorProto.INT32)
1292
+
1293
+ def _infer_MatMulNBits(self, node): # noqa: N802
1294
+ lhs_shape = self._get_shape(node, 0)
1295
+ rhs_shape = [get_attribute(node, "K"), get_attribute(node, "N")]
1296
+ lhs_rank = len(lhs_shape)
1297
+ assert lhs_rank > 0
1298
+ if lhs_rank == 1:
1299
+ new_shape = rhs_shape[1:]
1300
+ else:
1301
+ new_shape = lhs_shape[:-1] + rhs_shape[1:]
1302
+ # merge reduce dim
1303
+ self._check_merged_dims(
1304
+ [lhs_shape[-1], rhs_shape[0]],
1305
+ allow_broadcast=False,
1306
+ )
1307
+ # infer output_dtype from input type when not specified
1308
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
1309
+ vi = self.known_vi_[node.output[0]]
1310
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
1311
+
1312
+ def _infer_NonMaxSuppression(self, node): # noqa: N802
1313
+ selected = str(self._new_symbolic_dim_from_output(node))
1314
+ vi = self.known_vi_[node.output[0]]
1315
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [selected, 3]))
1316
+
1317
+ def _infer_NonZero(self, node): # noqa: N802
1318
+ input_rank = self._get_shape_rank(node, 0)
1319
+ # create a new symbolic dimension for NonZero output
1320
+ nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1))
1321
+ vi = self.known_vi_[node.output[0]]
1322
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len]))
1323
+
1324
+ def _infer_OneHot(self, node): # noqa: N802
1325
+ sympy_shape = self._get_sympy_shape(node, 0)
1326
+ depth = self._try_get_value(node, 1)
1327
+ axis = get_attribute(node, "axis", -1)
1328
+ axis = handle_negative_axis(axis, len(sympy_shape) + 1)
1329
+ new_shape = get_shape_from_sympy_shape(
1330
+ [
1331
+ *sympy_shape[:axis],
1332
+ self._new_symbolic_dim_from_output(node) if not is_literal(depth) else depth,
1333
+ *sympy_shape[axis:],
1334
+ ]
1335
+ )
1336
+ vi = self.known_vi_[node.output[0]]
1337
+ vi.CopyFrom(
1338
+ helper.make_tensor_value_info(
1339
+ node.output[0],
1340
+ self.known_vi_[node.input[2]].type.tensor_type.elem_type,
1341
+ new_shape,
1342
+ )
1343
+ )
1344
+
1345
+ def _infer_Pad(self, node): # noqa: N802
1346
+ if get_opset(self.out_mp_) <= 10:
1347
+ pads = get_attribute(node, "pads")
1348
+ else:
1349
+ pads = self._try_get_value(node, 1)
1350
+
1351
+ sympy_shape = self._get_sympy_shape(node, 0)
1352
+ rank = len(sympy_shape)
1353
+
1354
+ if pads is not None:
1355
+ assert len(pads) == 2 * rank
1356
+ new_sympy_shape = [
1357
+ d + pad_up + pad_down
1358
+ for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:], strict=False)
1359
+ ]
1360
+ self._update_computed_dims(new_sympy_shape)
1361
+ else:
1362
+ # dynamic pads, create new symbolic dimensions
1363
+ new_sympy_shape = self._new_symbolic_shape(rank, node)
1364
+ output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type
1365
+
1366
+ vi = self.known_vi_[node.output[0]]
1367
+ vi.CopyFrom(
1368
+ helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape))
1369
+ )
1370
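+ # Worked example: for rank 2, pads are laid out [begin_0, begin_1, end_0,
+ # end_1], so a shape (d0, 8) with pads [1, 0, 2, 0] becomes (d0 + 3, 8);
+ # unknown (dynamic) pads fall back to fresh symbolic dimensions.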
+
1371
+ def _infer_Pool(self, node): # noqa: N802
1372
+ sympy_shape = self._compute_conv_pool_shape(node)
1373
+ self._update_computed_dims(sympy_shape)
1374
+ for o in node.output:
1375
+ if not o:
1376
+ continue
1377
+ vi = self.known_vi_[o]
1378
+ vi.CopyFrom(
1379
+ helper.make_tensor_value_info(
1380
+ o,
1381
+ vi.type.tensor_type.elem_type,
1382
+ get_shape_from_sympy_shape(sympy_shape),
1383
+ )
1384
+ )
1385
+
1386
+ def _infer_aten_bitwise_or(self, node):
1387
+ shape0 = self._get_shape(node, 0)
1388
+ shape1 = self._get_shape(node, 1)
1389
+ new_shape = self._broadcast_shapes(shape0, shape1)
1390
+ t0 = self.known_vi_[node.input[0]]
1391
+ vi = self.known_vi_[node.output[0]]
1392
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, new_shape))
1393
+
1394
+ def _infer_aten_diagonal(self, node):
1395
+ sympy_shape = self._get_sympy_shape(node, 0)
1396
+ rank = len(sympy_shape)
1397
+ offset = self._try_get_value(node, 1)
1398
+ dim1 = self._try_get_value(node, 2)
1399
+ dim2 = self._try_get_value(node, 3)
1400
+
1401
+ assert offset is not None and dim1 is not None and dim2 is not None
1402
+ dim1 = handle_negative_axis(dim1, rank)
1403
+ dim2 = handle_negative_axis(dim2, rank)
1404
+
1405
+ new_shape = []
1406
+ for dim, val in enumerate(sympy_shape):
1407
+ if dim not in [dim1, dim2]:
1408
+ new_shape.append(val)
1409
+
1410
+ shape1 = sympy_shape[dim1]
1411
+ shape2 = sympy_shape[dim2]
1412
+ if offset >= 0:
1413
+ diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset))
1414
+ else:
1415
+ diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2))
1416
+ new_shape.append(diag_shape)
1417
+
1418
+ if node.output[0]:
1419
+ vi = self.known_vi_[node.output[0]]
1420
+ vi.CopyFrom(
1421
+ helper.make_tensor_value_info(
1422
+ node.output[0],
1423
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1424
+ get_shape_from_sympy_shape(new_shape),
1425
+ )
1426
+ )
1427
+
1428
+ def _infer_aten_multinomial(self, node):
1429
+ sympy_shape = self._get_sympy_shape(node, 0)
1430
+ rank = len(sympy_shape)
1431
+ assert rank in [1, 2]
1432
+ num_samples = self._try_get_value(node, 1)
1433
+ di = rank - 1
1434
+ last_dim = num_samples if num_samples else str(self._new_symbolic_dim_from_output(node, 0, di))
1435
+ output_shape = [*sympy_shape[:-1], last_dim]
1436
+ vi = self.known_vi_[node.output[0]]
1437
+ vi.CopyFrom(
1438
+ helper.make_tensor_value_info(
1439
+ node.output[0],
1440
+ onnx.TensorProto.INT64,
1441
+ get_shape_from_sympy_shape(output_shape),
1442
+ )
1443
+ )
1444
+
1445
+ def _infer_aten_pool2d(self, node):
1446
+ sympy_shape = self._get_sympy_shape(node, 0)
1447
+ assert len(sympy_shape) == 4
1448
+ sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]]
1449
+ self._update_computed_dims(sympy_shape)
1450
+ for i, o in enumerate(node.output):
1451
+ if not o:
1452
+ continue
1453
+ vi = self.known_vi_[o]
1454
+ elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[node.input[0]].type.tensor_type.elem_type
1455
+ vi.CopyFrom(helper.make_tensor_value_info(o, elem_type, get_shape_from_sympy_shape(sympy_shape)))
1456
+
1457
+ def _infer_aten_minmax(self, node):
1458
+ vi = self.known_vi_[node.output[0]]
1459
+ if len(node.input) == 1:
1460
+ vi.CopyFrom(
1461
+ helper.make_tensor_value_info(
1462
+ node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, []
1463
+ )
1464
+ )
1465
+ else:
1466
+ assert len(node.input) == 3
1467
+ keepdim = self._try_get_value(node, 2)
1468
+ assert keepdim is not None # can only handle known keepdim case.
1469
+ dim = self._try_get_value(node, 1)
1470
+ if dim is None:
1471
+ rank = self._get_shape_rank(node, 0)
1472
+ output_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node)
1473
+ else:
1474
+ shape = self._get_sympy_shape(node, 0)
1475
+ dim = handle_negative_axis(dim, len(shape))
1476
+ output_shape = shape[:dim]
1477
+ if keepdim:
1478
+ output_shape += [1]
1479
+ output_shape += shape[dim + 1 :]
1480
+
1481
+ output_shape = get_shape_from_sympy_shape(output_shape)
1482
+ vi.CopyFrom(
1483
+ helper.make_tensor_value_info(
1484
+ node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, output_shape
1485
+ )
1486
+ )
1487
+ vi1 = self.known_vi_[node.output[1]]
1488
+ vi1.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape))
1489
+
1490
+ def _infer_aten_unfold(self, node):
1491
+ sympy_shape = self._get_sympy_shape(node, 0)
1492
+ dimension = self._try_get_value(node, 1)
1493
+ size = self._try_get_value(node, 2)
1494
+ step = self._try_get_value(node, 3)
1495
+ if dimension is not None and size is not None and step is not None:
1496
+ assert dimension < len(sympy_shape)
1497
+ sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1
1498
+ sympy_shape.append(size)
1499
+ else:
1500
+ rank = len(sympy_shape)
1501
+ sympy_shape = self._new_symbolic_shape(rank + 1, node)
1502
+ self._update_computed_dims(sympy_shape)
1503
+ if node.output[0]:
1504
+ vi = self.known_vi_[node.output[0]]
1505
+ vi.CopyFrom(
1506
+ helper.make_tensor_value_info(
1507
+ node.output[0],
1508
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1509
+ get_shape_from_sympy_shape(sympy_shape),
1510
+ )
1511
+ )
1512
+
1513
+ def _infer_aten_argmax(self, node):
1514
+ new_shape = None
1515
+ if not node.input[1]:
1516
+ # The argmax of the flattened input is returned.
1517
+ new_shape = []
1518
+ else:
1519
+ dim = self._try_get_value(node, 1)
1520
+ keepdim = self._try_get_value(node, 2)
1521
+ if keepdim is not None:
1522
+ sympy_shape = self._get_sympy_shape(node, 0)
1523
+ if dim is not None:
1524
+ dim = handle_negative_axis(dim, len(sympy_shape))
1525
+ if keepdim:
1526
+ sympy_shape[dim] = 1
1527
+ else:
1528
+ del sympy_shape[dim]
1529
+ else:
1530
+ rank = len(sympy_shape)
1531
+ sympy_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node)
1532
+ self._update_computed_dims(sympy_shape)
1533
+ new_shape = get_shape_from_sympy_shape(sympy_shape)
1534
+ if node.output[0] and new_shape is not None:
1535
+ vi = self.known_vi_[node.output[0]]
1536
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))
1537
+
1538
+ def _infer_aten_group_norm(self, node):
1539
+ self._propagate_shape_and_type(node)
1540
+ input_shape = self._get_shape(node, 0)
1541
+ N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None # noqa: N806
1542
+ group = self._try_get_value(node, 6)
1543
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
1544
+ for i in [1, 2]:
1545
+ if node.output[i]:
1546
+ vi = self.known_vi_[node.output[i]]
1547
+ vi.CopyFrom(
1548
+ helper.make_tensor_value_info(
1549
+ node.output[i],
1550
+ output_dtype,
1551
+ [
1552
+ N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0)),
1553
+ (
1554
+ as_scalar(group)
1555
+ if group is not None
1556
+ else str(self._new_symbolic_dim_from_output(node, i, 1))
1557
+ ),
1558
+ ],
1559
+ )
1560
+ )
1561
+
1562
+ def _infer_aten_upsample(self, node):
1563
+ new_shape = None
1564
+ input_shape = self._get_shape(node, 0)
1565
+ if input_shape is not None:
1566
+ new_shape = input_shape[:2]
1567
+ output_size = self._try_get_value(node, 1)
1568
+ if output_size is not None:
1569
+ new_shape += [dim_size.item() if type(dim_size) is np.int64 else dim_size for dim_size in output_size]
1570
+ else:
1571
+ rank = len(input_shape)
1572
+ new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]
1573
+ if node.output[0] and new_shape is not None:
1574
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
1575
+ vi = self.known_vi_[node.output[0]]
1576
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
1577
+
1578
+ def _infer_BatchNormalization(self, node): # noqa: N802
1579
+ self._propagate_shape_and_type(node)
1580
+
1581
+ # this works both for opsets < 14 and for 14+, since we check i < len(node.output) in the loop
1582
+ for i in [1, 2, 3, 4]:
1583
+ if i < len(node.output) and node.output[i]:
1584
+ # all of these outputs have the same shape as the scale input (input index 1)
1585
+ self._propagate_shape_and_type(node, input_index=1, output_index=i)
1586
+
1587
+ def _infer_Range(self, node): # noqa: N802
1588
+ vi = self.known_vi_[node.output[0]]
1589
+ input_data = self._get_int_or_float_values(node)
1590
+ if all(i is not None for i in input_data):
1591
+ start = as_scalar(input_data[0])
1592
+ limit = as_scalar(input_data[1])
1593
+ delta = as_scalar(input_data[2])
1594
+ new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)]
1595
+ else:
1596
+ new_sympy_shape = [self._new_symbolic_dim_from_output(node)]
1597
+ self._update_computed_dims(new_sympy_shape)
1598
+ vi.CopyFrom(
1599
+ helper.make_tensor_value_info(
1600
+ node.output[0],
1601
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1602
+ get_shape_from_sympy_shape(new_sympy_shape),
1603
+ )
1604
+ )
1605
+
1606
+ def _infer_ReduceSum(self, node): # noqa: N802
1607
+ keep_dims = get_attribute(node, "keepdims", 1)
1608
+ if get_opset(self.out_mp_) >= 13 and len(node.input) > 1:
1609
+ # ReduceSum changes axes to input[1] in opset 13
1610
+ axes = self._try_get_value(node, 1)
1611
+ vi = self.known_vi_[node.output[0]]
1612
+ if axes is None:
1613
+ assert keep_dims  # can only handle keep_dims==True when axes is unknown; we fall back to a fresh symbolic shape of the same rank
1614
+ vi.CopyFrom(
1615
+ helper.make_tensor_value_info(
1616
+ node.output[0],
1617
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1618
+ get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)),
1619
+ )
1620
+ )
1621
+ else:
1622
+ shape = self._get_shape(node, 0)
1623
+ output_shape = []
1624
+ axes = [handle_negative_axis(a, len(shape)) for a in axes]
1625
+ for i, d in enumerate(shape):
1626
+ if i in axes:
1627
+ if keep_dims:
1628
+ output_shape.append(1)
1629
+ else:
1630
+ output_shape.append(d)
1631
+ vi.CopyFrom(
1632
+ helper.make_tensor_value_info(
1633
+ node.output[0],
1634
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1635
+ output_shape,
1636
+ )
1637
+ )
1638
+
1639
+ def _infer_ReduceMean(self, node): # noqa: N802
1640
+ if get_opset(self.out_mp_) >= 18:
1641
+ # the ReduceMean spec for opset 18+ matches the ReduceSum spec for opset 13+
1642
+ self._infer_ReduceSum(node)
1643
+
1644
+ def _infer_ReduceProd(self, node): # noqa: N802
1645
+ axes = get_attribute(node, "axes")
1646
+ keep_dims = get_attribute(node, "keepdims", 1)
1647
+ if keep_dims == 0 and axes == [0]:
1648
+ data = self._get_int_or_float_values(node)[0]
1649
+ if data is not None:
1650
+ self.sympy_data_[node.output[0]] = sympy_reduce_product(data)
1651
+
1652
+ def _infer_RelativePositionBias(self, node): # noqa: N802
1653
+ seq_len = self._try_get_value(node, 1)
1654
+ real_seq_len = self._try_get_value(node, 2)
1655
+ if seq_len is None or real_seq_len is None:
1656
+ return
1657
+ num_heads = self._get_sympy_shape(node, 0)[1]
1658
+
1659
+ new_shape = [1, num_heads, str(seq_len), str(real_seq_len)]
1660
+
1661
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
1662
+ vi = self.known_vi_[node.output[0]]
1663
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
1664
+
1665
+ def _infer_Reshape(self, node): # noqa: N802
1666
+ shape_value = self._try_get_value(node, 1)
1667
+ vi = self.known_vi_[node.output[0]]
1668
+ if shape_value is None:
1669
+ shape_shape = self._get_shape(node, 1)
1670
+ assert len(shape_shape) == 1
1671
+ shape_rank = shape_shape[0]
1672
+ assert is_literal(shape_rank)
1673
+ vi.CopyFrom(
1674
+ helper.make_tensor_value_info(
1675
+ node.output[0],
1676
+ vi.type.tensor_type.elem_type,
1677
+ get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)),
1678
+ )
1679
+ )
1680
+ else:
1681
+ input_sympy_shape = self._get_sympy_shape(node, 0)
1682
+ total = 1
1683
+ for d in input_sympy_shape:
1684
+ total = total * d
1685
+ new_sympy_shape = []
1686
+ deferred_dim_idx = -1
1687
+ non_deferred_size = 1
1688
+ for i, d in enumerate(shape_value):
1689
+ if type(d) is sympy.Symbol:
1690
+ new_sympy_shape.append(d)
1691
+ elif d == 0:
1692
+ new_sympy_shape.append(input_sympy_shape[i])
1693
+ non_deferred_size = non_deferred_size * input_sympy_shape[i]
1694
+ else:
1695
+ new_sympy_shape.append(d)
1696
+ if d == -1:
1697
+ deferred_dim_idx = i
1698
+ elif d != 0:
1699
+ non_deferred_size = non_deferred_size * d
1700
+
1701
+ assert new_sympy_shape.count(-1) < 2
1702
+ if -1 in new_sympy_shape:
1703
+ new_dim = total // non_deferred_size
1704
+ new_sympy_shape[deferred_dim_idx] = new_dim
1705
+
1706
+ self._update_computed_dims(new_sympy_shape)
1707
+ vi.CopyFrom(
1708
+ helper.make_tensor_value_info(
1709
+ node.output[0],
1710
+ vi.type.tensor_type.elem_type,
1711
+ get_shape_from_sympy_shape(new_sympy_shape),
1712
+ )
1713
+ )
1714
+
1715
+ self._pass_on_sympy_data(node)
1716
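+ # Worked example: reshaping an input of shape (2, 3, 4) with shape value
+ # (0, -1): the 0 copies dim 0 (= 2), and the deferred -1 becomes
+ # total / non_deferred_size = 24 / 2 = 12, giving (2, 12).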
+
1717
+ def _infer_Resize(self, node): # noqa: N802
1718
+ vi = self.known_vi_[node.output[0]]
1719
+ input_sympy_shape = self._get_sympy_shape(node, 0)
1720
+ if get_opset(self.out_mp_) <= 10:
1721
+ scales = self._try_get_value(node, 1)
1722
+ if scales is not None:
1723
+ new_sympy_shape = [
1724
+ sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales, strict=False)
1725
+ ]
1726
+ self._update_computed_dims(new_sympy_shape)
1727
+ vi.CopyFrom(
1728
+ helper.make_tensor_value_info(
1729
+ node.output[0],
1730
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1731
+ get_shape_from_sympy_shape(new_sympy_shape),
1732
+ )
1733
+ )
1734
+ else:
1735
+ roi = self._try_get_value(node, 1)
1736
+ scales = self._try_get_value(node, 2)
1737
+ sizes = self._try_get_value(node, 3)
1738
+ if sizes is not None:
1739
+ new_sympy_shape = [sympy.simplify(sympy.floor(s)) for s in sizes]
1740
+ self._update_computed_dims(new_sympy_shape)
1741
+ elif scales is not None:
1742
+ rank = len(scales)
1743
+ if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize":
1744
+ assert len(roi) == 2 * rank
1745
+ roi_start = list(roi)[:rank]
1746
+ roi_end = list(roi)[rank:]
1747
+ else:
1748
+ roi_start = [0] * rank
1749
+ roi_end = [1] * rank
1750
+ scales = list(scales)
1751
+ new_sympy_shape = [
1752
+ sympy.simplify(sympy.floor(d * (end - start) * scale))
1753
+ for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales, strict=False)
1754
+ ]
1755
+ self._update_computed_dims(new_sympy_shape)
1756
+ else:
1757
+ new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
1758
+
1759
+ vi.CopyFrom(
1760
+ helper.make_tensor_value_info(
1761
+ node.output[0],
1762
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1763
+ get_shape_from_sympy_shape(new_sympy_shape),
1764
+ )
1765
+ )
1766
+
1767
+ def _infer_Scan(self, node): # noqa: N802
1768
+ subgraph = get_attribute(node, "body")
1769
+ num_scan_inputs = get_attribute(node, "num_scan_inputs")
1770
+ scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs)
1771
+ num_scan_states = len(node.input) - num_scan_inputs
1772
+ scan_input_axes = [
1773
+ handle_negative_axis(ax, self._get_shape_rank(node, i + num_scan_states))
1774
+ for i, ax in enumerate(scan_input_axes)
1775
+ ]
1776
+ # There may be cases where the subgraph has optional inputs that appear in both the subgraph's inputs and initializers,
1777
+ # but not in the node's inputs. Such a model may be invalid, but we skip those optional inputs.
1778
+ assert len(subgraph.input) >= len(node.input)
1779
+ subgraph_inputs = subgraph.input[: len(node.input)]
1780
+ for i, si in enumerate(subgraph_inputs):
1781
+ subgraph_name = si.name
1782
+ si.CopyFrom(self.known_vi_[node.input[i]])
1783
+ if i >= num_scan_states:
1784
+ scan_input_dim = si.type.tensor_type.shape.dim
1785
+ scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]])
1786
+ si.name = subgraph_name
1787
+ self._onnx_infer_subgraph(node, subgraph)
1788
+ num_scan_outputs = len(node.output) - num_scan_states
1789
+ scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs)
1790
+ scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
1791
+ for i, o in enumerate(node.output):
1792
+ vi = self.known_vi_[o]
1793
+ if i >= num_scan_states:
1794
+ shape = get_shape_from_type_proto(subgraph.output[i].type)
1795
+ new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1)
1796
+ shape = [*shape[:new_dim], scan_input_dim, *shape[new_dim:]]
1797
+ vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape))
1798
+ else:
1799
+ vi.CopyFrom(subgraph.output[i])
1800
+ vi.name = o
1801
+
1802
+ def _infer_ScatterElements(self, node): # noqa: N802
1803
+ data_shape = self._get_shape(node, 0)
1804
+ vi = self.known_vi_[node.output[0]]
1805
+ vi.CopyFrom(
1806
+ helper.make_tensor_value_info(
1807
+ node.output[0],
1808
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1809
+ data_shape,
1810
+ )
1811
+ )
1812
+
1813
+ def _infer_SequenceAt(self, node): # noqa: N802
1814
+ # need to create a new symbolic dimension if the sequence shape contains None
1815
+ seq_shape = self._get_shape(node, 0)
1816
+ vi = self.known_vi_[node.output[0]]
1817
+ if seq_shape is not None:
1818
+ for di, d in enumerate(seq_shape):
1819
+ if d is not None:
1820
+ continue
1821
+ new_dim = onnx.TensorShapeProto.Dimension()
1822
+ new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, 0, di))
1823
+ vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
1824
+
1825
+ def _infer_SequenceInsert(self, node): # noqa: N802
1826
+ # work around a bug in onnx's shape inference
1827
+ vi_seq = self.known_vi_[node.input[0]]
1828
+ vi_tensor = self.known_vi_[node.input[1]]
1829
+ vi_out_seq = self.known_vi_[node.output[0]]
1830
+ vi_out_seq.CopyFrom(vi_seq)
1831
+ vi_out_seq.name = node.output[0]
1832
+ self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type)
1833
+
1834
+ def _infer_Shape(self, node): # noqa: N802
1835
+ self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0)
1836
+
1837
+ def _infer_Size(self, node): # noqa: N802
1838
+ sympy_shape = self._get_sympy_shape(node, 0)
1839
+ self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape)
1840
+ self.known_vi_[node.output[0]].CopyFrom(
1841
+ helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])
1842
+ )
1843
+
1844
+ def _infer_Slice(self, node): # noqa: N802
1845
+ # SymPy fails to prove that `x_0 + ... + x_n >= 0` if one of `x_i` is a `sympy.Min(a, b)`,
1846
+ # even when the relation holds for both `a` and `b`.
1847
+ #
1848
+ # When given `expr` of form `min(a, b) + ...`, this function returns `[a + ..., b + ...]`,
1849
+ # so that we can prove inequalities for both expressions separately.
1850
+ #
1851
+ # If the number of `min(...)` subexpressions is not exactly one, this function just returns `[expr]`.
1852
+ def flatten_min(expr):
1853
+ assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}"
1854
+ min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)]
1855
+ if len(min_positions) == 1:
1856
+ min_pos = min_positions[0]
1857
+
1858
+ def replace_min_with_arg(arg_idx):
1859
+ replaced = list(expr.args)
1860
+ assert isinstance(replaced[min_pos], sympy.Min), (
1861
+ f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}"
1862
+ )
1863
+ assert len(replaced[min_pos].args) == 2, (
1864
+ f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}"
1865
+ )
1866
+ replaced[min_pos] = replaced[min_pos].args[arg_idx]
1867
+ return sympy.Add(*replaced)
1868
+
1869
+ return [
1870
+ replace_min_with_arg(0),
1871
+ replace_min_with_arg(1),
1872
+ ]
1873
+ return [expr]
1874
+
1875
+ def less_equal(x, y):
1876
+ try:
1877
+ return bool(x <= y)
1878
+ except TypeError:
1879
+ pass
1880
+ try:
1881
+ return bool(y >= x)
1882
+ except TypeError:
1883
+ pass
1884
+ try:
1885
+ return bool(-x >= -y)
1886
+ except TypeError:
1887
+ pass
1888
+ try:
1889
+ return bool(-y <= -x)
1890
+ except TypeError:
1891
+ pass
1892
+ try:
1893
+ return bool(y - x >= 0)
1894
+ except TypeError:
1895
+ # the last attempt; this may raise TypeError
1896
+ return all(bool(d >= 0) for d in flatten_min(y - x))
1897
+
1898
+ def handle_negative_index(index, bound):
1899
+ """normalizes a negative index to be in [0, bound)"""
1900
+ try:
1901
+ if not less_equal(0, index):
1902
+ if is_literal(index) and index <= -self.int_max_:
1903
+ # this case is handled separately
1904
+ return index
1905
+ return bound + index
1906
+ except TypeError:
1907
+ logger.warning(f"Cannot determine if {index} < 0")
1908
+ return index
1909
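+ # Example: handle_negative_index(-2, d) normalizes to d - 2 once
+ # less_equal(0, -2) evaluates to False; a literal index at or below
+ # -int_max_ is returned unchanged and treated as a sentinel by the caller.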
+
1910
+ if get_opset(self.out_mp_) <= 9:
1911
+ axes = get_attribute(node, "axes")
1912
+ starts = get_attribute(node, "starts")
1913
+ ends = get_attribute(node, "ends")
1914
+ if not axes:
1915
+ axes = list(range(len(starts)))
1916
+ steps = [1] * len(axes)
1917
+ else:
1918
+ starts = as_list(self._try_get_value(node, 1), keep_none=True)
1919
+ ends = as_list(self._try_get_value(node, 2), keep_none=True)
1920
+ axes = self._try_get_value(node, 3)
1921
+ steps = self._try_get_value(node, 4)
1922
+ if axes is None and not (starts is None and ends is None):
1923
+ axes = list(range(len(starts if starts is not None else ends)))
1924
+ if steps is None and not (starts is None and ends is None):
1925
+ steps = [1] * len(starts if starts is not None else ends)
1926
+ axes = as_list(axes, keep_none=True)
1927
+ steps = as_list(steps, keep_none=True)
1928
+
1929
+ new_sympy_shape = self._get_sympy_shape(node, 0)
1930
+ if starts is None or ends is None:
1931
+ if axes is None:
1932
+ for i in range(len(new_sympy_shape)):
1933
+ new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
1934
+ else:
1935
+ new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
1936
+ for i in axes:
1937
+ new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
1938
+ else:
1939
+ for i, s, e, t in zip(axes, starts, ends, steps, strict=False):
1940
+ e = handle_negative_index(e, new_sympy_shape[i]) # noqa: PLW2901
1941
+ if is_literal(e):
1942
+ if e >= self.int_max_:
1943
+ e = new_sympy_shape[i] # noqa: PLW2901
1944
+ elif e <= -self.int_max_:
1945
+ e = 0 if s > 0 else -1 # noqa: PLW2901
1946
+ elif is_literal(new_sympy_shape[i]):
1947
+ if e < 0:
1948
+ e = max(0, e + new_sympy_shape[i]) # noqa: PLW2901
1949
+ e = min(e, new_sympy_shape[i]) # noqa: PLW2901
1950
+ else:
1951
+ if e > 0:
1952
+ e = ( # noqa: PLW2901
1953
+ sympy.Min(e, new_sympy_shape[i]) if e > 1 else e
1954
+ ) # special case for slicing first to make computation easier
1955
+ else:
1956
+ if is_literal(new_sympy_shape[i]):
1957
+ e = sympy.Min(e, new_sympy_shape[i]) # noqa: PLW2901
1958
+ else:
1959
+ try:
1960
+ if not less_equal(e, new_sympy_shape[i]):
1961
+ e = new_sympy_shape[i] # noqa: PLW2901
1962
+ except Exception:
1963
+ logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal")
1964
+ e = new_sympy_shape[i] # noqa: PLW2901
1965
+
1966
+ s = handle_negative_index(s, new_sympy_shape[i]) # noqa: PLW2901
1967
+ if is_literal(new_sympy_shape[i]) and is_literal(s):
1968
+ s = max(0, min(s, new_sympy_shape[i])) # noqa: PLW2901
1969
+
1970
+ new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)
1971
+
1972
+ self._update_computed_dims(new_sympy_shape)
1973
+
1974
+ vi = self.known_vi_[node.output[0]]
1975
+ vi.CopyFrom(
1976
+ helper.make_tensor_value_info(
1977
+ node.output[0],
1978
+ vi.type.tensor_type.elem_type,
1979
+ get_shape_from_sympy_shape(new_sympy_shape),
1980
+ )
1981
+ )
1982
+
1983
+ # handle sympy_data if needed, for slice in shape computation
1984
+ if (
1985
+ node.input[0] in self.sympy_data_
1986
+ and axes == [0]
1987
+ and starts is not None
1988
+ and len(starts) == 1
1989
+ and ends is not None
1990
+ and len(ends) == 1
1991
+ and steps is not None
1992
+ and len(steps) == 1
1993
+ ):
1994
+ input_sympy_data = self.sympy_data_[node.input[0]]
1995
+ if type(input_sympy_data) is list or (
1996
+ isinstance(input_sympy_data, np.ndarray) and len(input_sympy_data.shape) == 1
1997
+ ):
1998
+ self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]]
1999
+
2000
+ def _infer_SoftmaxCrossEntropyLoss(self, node): # noqa: N802
2001
+ vi = self.known_vi_[node.output[0]]
2002
+ elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
2003
+
2004
+ # If the output type is explicitly specified in an attribute, use it as the output tensor type.
2005
+ specified_output_type = get_attribute(node, "output_type", None)
2006
+ if specified_output_type is not None:
2007
+ elem_type = specified_output_type
2008
+
2009
+ vi.type.tensor_type.elem_type = elem_type
2010
+ vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
2011
+
2012
+ if len(node.output) > 1:
2013
+ data_shape = self._get_shape(node, 0)
2014
+ vi = self.known_vi_[node.output[1]]
2015
+ vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, data_shape))
2016
+
2017
+ def _infer_Split_Common(self, node, make_value_info_func): # noqa: N802
2018
+ input_sympy_shape = self._get_sympy_shape(node, 0)
2019
+ axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape))
2020
+ op_set = get_opset(self.out_mp_)
2021
+
2022
+ # Depending on the opset version, 'split' is provided either as an attribute or via the 2nd input
2023
+ if op_set < 13:
2024
+ split = get_attribute(node, "split")
2025
+ assert self._try_get_value(node, 1) is None
2026
+ else:
2027
+ split = self._try_get_value(node, 1)
2028
+ assert get_attribute(node, "split") is None
2029
+
2030
+ if split is None:
2031
+ num_outputs = len(node.output)
2032
+ split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs
2033
+ self._update_computed_dims(split)
2034
+ else:
2035
+ split = [sympy.Integer(s) for s in split]
2036
+
2037
+ for i_o in range(len(split)):
2038
+ vi = self.known_vi_[node.output[i_o]]
2039
+ vi.CopyFrom(
2040
+ make_value_info_func(
2041
+ node.output[i_o],
2042
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
2043
+ get_shape_from_sympy_shape([*input_sympy_shape[:axis], split[i_o], *input_sympy_shape[axis + 1 :]]),
2044
+ )
2045
+ )
2046
+ self.known_vi_[vi.name] = vi
2047
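+ # Worked example: splitting an axis of length 10 into 2 outputs without an
+ # explicit 'split' gives [5, 5]; for a symbolic axis d, each output gets
+ # d/2, registered as a computed dim via _update_computed_dims.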
+
2048
+ def _infer_Split(self, node): # noqa: N802
2049
+ self._infer_Split_Common(node, helper.make_tensor_value_info)
2050
+
2051
+ def _infer_SplitToSequence(self, node): # noqa: N802
2052
+ self._infer_Split_Common(node, helper.make_sequence_value_info)
2053
+
2054
+ def _infer_Squeeze(self, node): # noqa: N802
2055
+ input_shape = self._get_shape(node, 0)
2056
+ op_set = get_opset(self.out_mp_)
2057
+
2058
+ # Depending on the opset version, 'axes' is provided either as an attribute or via the 2nd input
2059
+ if op_set < 13:
2060
+ axes = get_attribute(node, "axes")
2061
+ assert self._try_get_value(node, 1) is None
2062
+ else:
2063
+ axes = self._try_get_value(node, 1)
2064
+ assert get_attribute(node, "axes") is None
2065
+
2066
+ if axes is None:
2067
+ # No axes have been provided (neither via attribute nor via input).
2068
+ # In this case the 'Squeeze' op should remove all axes with dimension 1.
2069
+ # For symbolic dimensions we assume they are != 1.
2070
+ output_shape = [s for s in input_shape if s != 1]
2071
+ if self.verbose_ > 0:
2072
+ symbolic_dimensions = [s for s in input_shape if type(s) != int] # noqa: E721
2073
+ if len(symbolic_dimensions) > 0:
2074
+ logger.debug(
2075
+ f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
2076
+ f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}"
2077
+ )
2078
+ else:
2079
+ axes = [handle_negative_axis(a, len(input_shape)) for a in axes]
2080
+ output_shape = []
2081
+ for i in range(len(input_shape)):
2082
+ if i not in axes:
2083
+ output_shape.append(input_shape[i])
2084
+ else:
2085
+ assert input_shape[i] == 1 or type(input_shape[i]) != int # noqa: E721
2086
+ if self.verbose_ > 0 and type(input_shape[i]) != int: # noqa: E721
2087
+ logger.debug(
2088
+ f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
2089
+ f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1."
2090
+ )
2091
+
2092
+ vi = self.known_vi_[node.output[0]]
2093
+ vi.CopyFrom(
2094
+ helper.make_tensor_value_info(
2095
+ node.output[0],
2096
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
2097
+ output_shape,
2098
+ )
2099
+ )
2100
+ self._pass_on_sympy_data(node)
2101
+
2102
+ def _infer_Tile(self, node): # noqa: N802
2103
+ repeats_value = self._try_get_value(node, 1)
2104
+ new_sympy_shape = []
2105
+ if repeats_value is not None:
2106
+ input_sympy_shape = self._get_sympy_shape(node, 0)
2107
+ for i, d in enumerate(input_sympy_shape):
2108
+ new_dim = d * repeats_value[i]
2109
+ new_sympy_shape.append(new_dim)
2110
+ self._update_computed_dims(new_sympy_shape)
2111
+ else:
2112
+ new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
2113
+ vi = self.known_vi_[node.output[0]]
2114
+ vi.CopyFrom(
2115
+ helper.make_tensor_value_info(
2116
+ node.output[0],
2117
+ vi.type.tensor_type.elem_type,
2118
+ get_shape_from_sympy_shape(new_sympy_shape),
2119
+ )
2120
+ )
2121
+
2122
+ def _infer_TopK(self, node): # noqa: N802
2123
+ rank = self._get_shape_rank(node, 0)
2124
+ axis = handle_negative_axis(get_attribute(node, "axis", -1), rank)
2125
+ new_shape = self._get_shape(node, 0)
2126
+
2127
+ if get_opset(self.out_mp_) <= 9:
2128
+ k = get_attribute(node, "k")
2129
+ else:
2130
+ k = self._get_int_or_float_values(node)[1]
2131
+
2132
+ if k is None:
2133
+ k = self._new_symbolic_dim_from_output(node)
2134
+ else:
2135
+ k = as_scalar(k)
2136
+
2137
+ if type(k) in [int, str]:
2138
+ new_shape[axis] = k
2139
+ else:
2140
+ new_sympy_shape = self._get_sympy_shape(node, 0)
2141
+ new_sympy_shape[axis] = k
2142
+ self._update_computed_dims(
2143
+ new_sympy_shape
2144
+ ) # note that the TopK dim may have been computed in sympy_data, so computed_dims must be updated when it enters the shape
2145
+ new_shape = get_shape_from_sympy_shape(new_sympy_shape)
2146
+
2147
+ for i_o in range(len(node.output)):
2148
+ vi = self.known_vi_[node.output[i_o]]
2149
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o], vi.type.tensor_type.elem_type, new_shape))
2150
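+ # Worked example: a literal k of 5 simply sets new_shape[axis] = 5 for both
+ # outputs (values and indices); a k that is itself a sympy expression is
+ # routed through _update_computed_dims before conversion.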
+
2151
+ def _infer_Transpose(self, node): # noqa: N802
2152
+ if node.input[0] in self.sympy_data_:
2153
+ data_shape = self._get_shape(node, 0)
2154
+ perm = get_attribute(node, "perm", reversed(list(range(len(data_shape)))))
2155
+ input_data = self.sympy_data_[node.input[0]]
2156
+ self.sympy_data_[node.output[0]] = (
2157
+ np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist()
2158
+ )
2159
+
2160
+ def _infer_Unsqueeze(self, node): # noqa: N802
2161
+ input_shape = self._get_shape(node, 0)
2162
+ op_set = get_opset(self.out_mp_)
2163
+
2164
+ # Depending on the opset version, 'axes' is provided either as an attribute or via the 2nd input
2165
+ if op_set < 13:
2166
+ axes = get_attribute(node, "axes")
2167
+ assert self._try_get_value(node, 1) is None
2168
+ else:
2169
+ axes = self._try_get_value(node, 1)
2170
+ assert get_attribute(node, "axes") is None
2171
+
2172
+ output_rank = len(input_shape) + len(axes)
2173
+ axes = [handle_negative_axis(a, output_rank) for a in axes]
2174
+
2175
+ input_axis = 0
2176
+ output_shape = []
2177
+ for i in range(output_rank):
2178
+ if i in axes:
2179
+ output_shape.append(1)
2180
+ else:
2181
+ output_shape.append(input_shape[input_axis])
2182
+ input_axis += 1
2183
+
2184
+ vi = self.known_vi_[node.output[0]]
2185
+ vi.CopyFrom(
2186
+ helper.make_tensor_value_info(
2187
+ node.output[0],
2188
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
2189
+ output_shape,
2190
+ )
2191
+ )
2192
+
2193
+ self._pass_on_sympy_data(node)
2194
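+ # Worked example: an input of shape (3, 4) with axes [0, 3] gives output
+ # rank 2 + 2 = 4 and shape (1, 3, 4, 1): positions listed in 'axes' receive
+ # the new 1-dims, and the rest are filled from the input in order.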
+
2195
+ def _infer_ZipMap(self, node): # noqa: N802
2196
+ map_key_type = None
2197
+ if get_attribute(node, "classlabels_int64s") is not None:
2198
+ map_key_type = onnx.TensorProto.INT64
2199
+ elif get_attribute(node, "classlabels_strings") is not None:
2200
+ map_key_type = onnx.TensorProto.STRING
2201
+
2202
+ assert map_key_type is not None
2203
+ new_vi = onnx.ValueInfoProto()
2204
+ new_vi.name = node.output[0]
2205
+ new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
2206
+ new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type
2207
+ vi = self.known_vi_[node.output[0]]
2208
+ vi.CopyFrom(new_vi)
2209
+
2210
+ def _infer_Attention(self, node): # noqa: N802
2211
+ shape = self._get_shape(node, 0)
2212
+ shape_weights = self._get_shape(node, 1)
2213
+ shape_bias = self._try_get_shape(node, 2)
2214
+ if shape_bias is not None:
2215
+ assert len(shape_bias) == 1
2216
+ tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
2217
+ if shape and len(shape) == 3:
2218
+ qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
2219
+ if qkv_hidden_sizes_attr is not None:
2220
+ assert len(qkv_hidden_sizes_attr) == 3
2221
+ shape[2] = int(qkv_hidden_sizes_attr[2])
2222
+ elif isinstance(tripled_hidden_size, int):
2223
+ shape[2] = int(tripled_hidden_size / 3)
2224
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
2225
+ vi = self.known_vi_[node.output[0]]
2226
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
2227
+
2228
+ if len(node.output) > 1:
2229
+ # input shape: (batch_size, sequence_length, hidden_size)
2230
+ # past shape: (2, batch_size, num_heads, past_sequence_length, head_size)
2231
+ # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len)
2232
+ # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length
2233
+ input_shape = self._get_shape(node, 0)
2234
+ past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else []
2235
+ mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else []
2236
+
2237
+ if past_shape and len(past_shape) == 5:
2238
+ if mask_shape and len(mask_shape) in [2, 3]:
2239
+ past_shape[3] = mask_shape[-1]
2240
+ elif input_shape and len(input_shape) == 3:
2241
+ if isinstance(input_shape[1], int) and isinstance(past_shape[3], int):
2242
+ past_shape[3] = input_shape[1] + past_shape[3]
2243
+ else:
2244
+ past_shape[3] = f"{past_shape[3]}+{input_shape[1]}"
2245
+ vi = self.known_vi_[node.output[1]]
2246
+ vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
2247
+ # No past input but present output still exists
2248
+ else:
2249
+ num_heads = get_attribute(node, "num_heads")
2250
+ head_size = input_shape[2] // num_heads
2251
+ present_shape = [2, input_shape[0], num_heads, input_shape[1], head_size]
2252
+ vi = self.known_vi_[node.output[1]]
2253
+ vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
2254
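+ # Worked example: with past shape (2, B, H, P, S) and mask shape (B, T),
+ # the present shape becomes (2, B, H, T, S); if P and the current sequence
+ # length are symbolic, dim 3 is rendered as the concatenated string "P+seq_len".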
+
2255
+ def _infer_GatedRelativePositionBias(self, node): # noqa: N802
2256
+ # When padding is removed:
2257
+ # query_layer: (token_count, num_heads x head_size)
2258
+ # token_offset: (batch_size, seq_len)
2259
+ # Otherwise:
2260
+ # query_layer: (batch_size, seq_len, num_heads x head_size)
2261
+ # token_offset: None
2262
+ # Output shape: (batch_size, num_heads, seq_len, seq_len)
2263
+ num_heads = get_attribute(node, "num_heads")
2264
+
2265
+ token_offset_shape = self._try_get_shape(node, 6)
2266
+ if token_offset_shape is not None:
2267
+ output_shape = [token_offset_shape[0], num_heads, token_offset_shape[1], token_offset_shape[1]]
2268
+ else:
2269
+ query_layer_shape = self._get_shape(node, 0)
2270
+ assert query_layer_shape is not None and len(query_layer_shape) == 3
2271
+ output_shape = [query_layer_shape[0], num_heads, query_layer_shape[1], query_layer_shape[1]]
2272
+
2273
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
2274
+ vi = self.known_vi_[node.output[0]]
2275
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
2276
+
2277
+    def _infer_PackedAttention(self, node):  # noqa: N802
+        shape = self._get_shape(node, 0)
+        shape_weights = self._get_shape(node, 1)
+        shape_bias = self._try_get_shape(node, 2)
+        if shape_bias is not None:
+            assert len(shape_bias) == 1
+        tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
+        if shape and len(shape) == 2:
+            qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
+            if qkv_hidden_sizes_attr is not None:
+                assert len(qkv_hidden_sizes_attr) == 3
+                shape[1] = int(qkv_hidden_sizes_attr[2])
+            elif isinstance(tripled_hidden_size, int):
+                shape[1] = int(tripled_hidden_size / 3)
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
+
+    def _infer_PackedMultiHeadAttention(self, node):  # noqa: N802
+        shape_value = self._try_get_shape(node, 2)
+        if shape_value is not None and len(shape_value) == 2:
+            output_shape = shape_value
+        else:
+            shape_query = self._get_shape(node, 0)
+            assert shape_query is not None and len(shape_query) == 4
+            output_shape = [shape_query[0], shape_query[1] * shape_query[3]]
+
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_RemovePadding(self, node):  # noqa: N802
+        shape = self._get_shape(node, 0)
+        if shape and len(shape) == 3:
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, ["token_count", shape[2]]))
+
+            vi_token_offset = self.known_vi_[node.output[1]]
+            vi_token_offset.CopyFrom(
+                helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]])
+            )
+
+            vi_cumulated_seq_len = self.known_vi_[node.output[2]]
+            vi_cumulated_seq_len.CopyFrom(
+                helper.make_tensor_value_info(node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"])
+            )
+
+            vi_max_seq_len = self.known_vi_[node.output[3]]
+            vi_max_seq_len.CopyFrom(helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1]))
+
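Dims like "token_count" and "batch_size + 1" above work because onnx.helper.make_tensor_value_info accepts a mix of integers and strings: an int becomes a concrete dim_value, a string becomes a symbolic dim_param. A small self-contained illustration (the tensor name and dims are made up):

import onnx
from onnx import helper

vi = helper.make_tensor_value_info("output", onnx.TensorProto.INT32, ["token_count", 768])
dims = vi.type.tensor_type.shape.dim
assert dims[0].dim_param == "token_count"  # symbolic dim stored as dim_param
assert dims[1].dim_value == 768            # concrete dim stored as dim_value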
+    def _infer_RestorePadding(self, node):  # noqa: N802
+        shape_input = self._get_shape(node, 0)
+        shape_token_offset = self._get_shape(node, 1)
+        if shape_input and len(shape_input) == 2 and shape_token_offset and len(shape_token_offset) == 2:
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = self.known_vi_[node.output[0]]
+
+            output_shape = [shape_token_offset[0], shape_token_offset[1], shape_input[1]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_BiasGelu(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+    def _infer_MultiHeadAttention(self, node):  # noqa: N802
+        # Output 0 has shape (batch_size, sequence_length, v_hidden_size)
+        # Q, K and V without packing:
+        #   Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
+        #   Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size)
+        #   Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size)
+        # Packed KV:
+        #   Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
+        #   Input 1 (key) has shape (batch_size, kv_sequence_length, num_heads, 2, head_size)
+        #   Input 2 (value) is nullptr
+        # Packed QKV:
+        #   Input 0 (query) has shape (batch_size, sequence_length, num_heads, 3, head_size)
+        #   Input 1 (key) is nullptr
+        #   Input 2 (value) is nullptr
+
+        query_shape = self._get_shape(node, 0)
+        total_sequence_length = None
+        output_dtype = None
+        if query_shape is not None:
+            if len(query_shape) == 3:
+                key_shape = self._try_get_shape(node, 1)
+                # By default, the hidden size is the same for Q/K/V; we only need to check v_hidden_size when value is provided.
+                output_shape = query_shape
+                if key_shape is not None and len(key_shape) == 3:
+                    value_shape = self._try_get_shape(node, 2)
+                    if value_shape is not None and len(value_shape) == 3:
+                        output_shape[2] = value_shape[2]
+                    total_sequence_length = key_shape[1]
+
+                output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+                vi = self.known_vi_[node.output[0]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            elif len(query_shape) == 5:
+                if isinstance(query_shape[2], int) and isinstance(query_shape[4], int):
+                    output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]]
+                else:
+                    output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"]
+
+                total_sequence_length = query_shape[1]
+
+                output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+                vi = self.known_vi_[node.output[0]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            if len(node.output) > 1:
+                batch_size = query_shape[0]
+                num_heads = get_attribute(node, "num_heads")
+
+                head_size = None
+                if len(query_shape) == 3:
+                    head_size = (
+                        int(query_shape[2] / num_heads)
+                        if isinstance(query_shape[2], int)
+                        else f"{query_shape[2]}/{num_heads}"
+                    )
+                else:
+                    head_size = query_shape[4]
+
+                past_shape = self._try_get_shape(node, 6)
+
+                if past_shape is not None:
+                    if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int):
+                        total_sequence_length = past_shape[2] + total_sequence_length
+                    else:
+                        total_sequence_length = f"{past_shape[2]}+{total_sequence_length}"
+
+                present_shape = [batch_size, num_heads, total_sequence_length, head_size]
+
+                assert output_dtype is not None
+                if len(node.output) > 2 and node.output[1] and node.output[2]:
+                    vi = self.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+                    vi = self.known_vi_[node.output[2]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+
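For the unpacked path above, the present key/value shape is (batch_size, num_heads, total_sequence_length, head_size), with head_size derived from the query hidden size. A worked example with assumed values (hidden_size=768, num_heads=12, past length 16, kv length 8; all illustrative):

hidden_size, num_heads = 768, 12
head_size = hidden_size // num_heads       # 64
batch_size, past_len, kv_len = 2, 16, 8
total_sequence_length = past_len + kv_len  # 24
present_shape = [batch_size, num_heads, total_sequence_length, head_size]
assert present_shape == [2, 12, 24, 64]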
+    def _infer_DecoderMaskedMultiHeadAttention(self, node):  # noqa: N802
+        # Output 0 has shape (batch_size, 1, v_hidden_size)
+        # Q, K and V without packing:
+        #   Input 0 (query) has shape (batch_size, 1, hidden_size)
+        #   Input 5 (past_key), if present, has shape (batch_size, num_heads, max_sequence_length, head_size)
+
+        query_shape = self._get_shape(node, 0)
+        if query_shape is not None:
+            output_shape = query_shape
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            assert output_dtype is not None
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            if len(node.output) > 2 and node.output[1] and node.output[2]:
+                past_shape = self._try_get_shape(node, 5)
+                if past_shape is not None:
+                    vi = self.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+                    vi = self.known_vi_[node.output[2]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+
+    def _infer_UnfoldTensor(self, node):  # noqa: N802
+        input_shape = self._get_shape(node, 0)
+        if input_shape is not None:
+            output_shape = input_shape.copy()
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            assert output_dtype is not None
+
+            rank, dim, size, step = len(input_shape), None, None, None
+            for attr in node.attribute:
+                if attr.name == "dim":
+                    dim = attr.i
+                    # normalize a negative axis to its non-negative equivalent
+                    dim = rank + dim if dim < 0 else dim
+                elif attr.name == "size":
+                    size = attr.i
+                elif attr.name == "step":
+                    step = attr.i
+
+            output_shape.append(size)
+            output_shape[dim] = (input_shape[dim] - size) // step + 1
+
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
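The output size along the unfolded axis follows (input - size) // step + 1, matching torch.Tensor.unfold semantics, and the window size is appended as a trailing axis. A worked example with assumed attribute values:

input_shape = [4, 32, 8]
dim, size, step = 1, 6, 2
output_shape = input_shape.copy()
output_shape.append(size)                                # trailing window axis
output_shape[dim] = (input_shape[dim] - size) // step + 1  # (32 - 6) // 2 + 1
assert output_shape == [4, 14, 8, 6]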
+    def _infer_DynamicTimeWarping(self, node):  # noqa: N802
+        # Input 0 has shape M x N or 1 x M x N
+        # Output 0 has shape (2, O) where max(M, N) <= O < M + N
+        input_shape = self._get_shape(node, 0)
+        if input_shape is not None:
+            shape_len = len(input_shape)
+            assert shape_len == 2 or shape_len == 3
+            M, N = input_shape[shape_len - 2], input_shape[shape_len - 1]  # noqa: N806
+            output_shape = [2, f"max({M}, {N}) <= O < {M} + {N}"]
+            output_dtype = onnx.TensorProto.FLOAT
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_FastGelu(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+    def _infer_Gelu(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+    def _infer_QuickGelu(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+    def _infer_GemmFastGelu(self, node):  # noqa: N802
+        self._compute_matmul_shape(node)
+
+    def _infer_GemmFloat8(self, node):  # noqa: N802
+        self._compute_matmul_shape(node)
+
+    def _infer_LayerNormalization(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+        if len(node.output) > 1:
+            axis = get_attribute(node, "axis")
+            if axis is None:
+                axis = -1
+            x_shape = self._get_shape(node, 0)
+            if x_shape is not None:
+                rank = len(x_shape)
+                axis = handle_negative_axis(axis, rank)
+                mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)]
+                mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+                if mean_dtype == onnx.TensorProto.FLOAT16 or mean_dtype == onnx.TensorProto.BFLOAT16:
+                    mean_dtype = onnx.TensorProto.FLOAT
+                vi = self.known_vi_[node.output[1]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape))
+                if len(node.output) > 2:
+                    vi = self.known_vi_[node.output[2]]
+                    vi.CopyFrom(helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape))
+
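With the axis normalized to a non-negative value, the mean and inverse-std outputs keep the leading dims and collapse everything from the axis onward to 1. A small sketch for a rank-3 input with axis=-1 (dim names illustrative; `axis % rank` is equivalent to handle_negative_axis here):

x_shape, axis = ["batch", "seq", 768], -1
rank = len(x_shape)
axis = axis % rank                            # -1 normalizes to 2
mean_shape = x_shape[:axis] + [1] * (rank - axis)
assert mean_shape == ["batch", "seq", 1]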
+    def _infer_LongformerAttention(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+    def _infer_EmbedLayerNormalization(self, node):  # noqa: N802
+        input_ids_shape = self._get_shape(node, 0)
+        word_embedding_shape = self._get_shape(node, 2)
+        assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2
+        output_shape = [*input_ids_shape, word_embedding_shape[1]]
+
+        word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape))
+
+        if len(node.output) > 1 and node.output[1]:
+            mask_index_shape = [input_ids_shape[0]]
+            vi = self.known_vi_[node.output[1]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape))
+
+        if len(node.output) > 2:
+            # Optional output of the Add before layer normalization;
+            # its shape is the same as the main output.
+            vi = self.known_vi_[node.output[2]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[2], word_embedding_dtype, output_shape))
+
+    def _infer_SkipLayerNormalization(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+        # If the SkipLayerNormalization node contains the optional
+        # output for inference, infer the shape and type for it too
+        if len(node.output) > 3:
+            self._propagate_shape_and_type(node, 0, 3)
+
+    def _infer_GroupNorm(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+    def _infer_PagedAttention(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+    def _infer_GroupQueryAttention(self, node):  # noqa: N802
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+
+        past_shape = self._try_get_shape(node, 3)
+        if past_shape is not None:
+            # When past and present have the maximum sequence length, we can propagate the shape from past to present.
+            # Note that GQA also supports different sequence lengths for past and present, but that is rarely used.
+            vi = self.known_vi_[node.output[1]]
+            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+            vi = self.known_vi_[node.output[2]]
+            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+
+        if node.input[1] != "" and node.input[2] != "":
+            self._propagate_shape_and_type(node, 0, 0)
+        else:
+            # combined qkv: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size)
+            assert node.input[1] == "" and node.input[2] == ""
+            num_heads = get_attribute(node, "num_heads")
+            kv_num_heads = get_attribute(node, "kv_num_heads")
+            query_shape = self._get_shape(node, 0)
+            if query_shape is not None:
+                hidden_size = query_shape[2]
+                if isinstance(hidden_size, int):
+                    head_size = int(hidden_size / (num_heads + 2 * kv_num_heads))
+                    query_shape[2] = num_heads * head_size
+                vi = self.known_vi_[node.output[0]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape))
+
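In the packed-QKV branch above, the combined hidden dim is num_heads * head_size + 2 * kv_num_heads * head_size, so head_size can be recovered by division and the output hidden size is num_heads * head_size. A worked check with assumed values (num_heads=32, kv_num_heads=8, head_size=128):

num_heads, kv_num_heads, head_size = 32, 8, 128
packed_hidden = (num_heads + 2 * kv_num_heads) * head_size   # 6144
recovered_head_size = packed_hidden // (num_heads + 2 * kv_num_heads)
assert recovered_head_size == head_size
assert num_heads * recovered_head_size == 4096               # dim 2 of output 0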
+    def _infer_SparseAttention(self, node):  # noqa: N802
+        self._infer_GroupQueryAttention(node)
+
+    def _infer_SkipGroupNorm(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node, 0, 0)
+        if len(node.output) > 1:
+            self._propagate_shape_and_type(node, 0, 1)
+
+    def _infer_BiasSplitGelu(self, node):  # noqa: N802
+        input_shape = self._get_shape(node, 0)
+        bias_shape = self._get_shape(node, 1)
+        if input_shape and bias_shape and isinstance(bias_shape[0], int):
+            output_shape = input_shape
+            output_shape[2] = int(bias_shape[0] / 2)
+            vi = self.known_vi_[node.output[0]]
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape))
+
+    def _infer_BiasAdd(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+    def _infer_RotaryEmbedding(self, node):  # noqa: N802
+        if len(node.output) == 1:
+            self._propagate_shape_and_type(node)
+        elif len(node.output) == 2:
+            # Extraneous constant nodes output by a RotaryEmbedding function made with `export_modules_as_functions`
+            self._propagate_shape_and_type(node, input_index=1, output_index=0)
+            self._propagate_shape_and_type(node, input_index=0, output_index=1)  # true output
+        elif len(node.output) == 3:
+            # Extraneous constant nodes output by a RotaryEmbedding function made with `export_modules_as_functions`
+            self._propagate_shape_and_type(node, input_index=1, output_index=0)
+            self._propagate_shape_and_type(node, input_index=1, output_index=1)
+            self._propagate_shape_and_type(node, input_index=0, output_index=2)  # true output
+
+    def _infer_PythonOp(self, node):  # noqa: N802
+        output_tensor_types = get_attribute(node, "output_tensor_types")
+        assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute."
+        output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
+        assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute."
+
+        from onnxruntime.capi._pybind_state import get_shape_inference_function  # noqa: PLC0415
+
+        func_name = get_attribute(node, "func_name").decode()
+        shape_inferer = get_shape_inference_function(func_name)
+
+        # Set the context output separately.
+        # The first output is torch.autograd.Function's context.
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))
+
+        if shape_inferer is not None:
+            input_shapes = []
+            input_dtypes = []
+            for input_index in range(len(node.input)):
+                shape = self._get_shape(node, input_index)
+                input_shapes.append(shape)
+                input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
+                input_dtypes.append(input_dtype)
+            output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes)
+            assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), (
+                f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, "
+                f"but expected {len(node.output) - 1} outputs."
+            )
+            for i in range(len(node.output) - 1):
+                output_index = i + 1
+                vi = self.known_vi_[node.output[output_index]]
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i])
+                )
+        else:
+            # General shape inference for PythonOp.
+            # Outputs after torch.autograd.Function's context are tensors.
+            # We assume their ranks are fixed for different model inputs.
+            for i in range(len(node.output) - 1):
+                # Process the i-th tensor output.
+                vi = self.known_vi_[node.output[i + 1]]
+                sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node)
+                shape = get_shape_from_sympy_shape(sympy_shape)
+                value_info = helper.make_tensor_value_info(node.output[i + 1], output_tensor_types[i], shape)
+                vi.CopyFrom(value_info)
+
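As the call above shows, a registered shape inferer takes (node, input_shapes, input_dtypes) and returns (output_shapes, output_dtypes) for the non-context outputs. A hedged sketch of what such a callback could look like; the function name is illustrative and the registration mechanism is not shown here:

def elementwise_shape_inferer(node, input_shapes, input_dtypes):
    # Sketch: an elementwise op keeps the first input's shape and dtype for its
    # single tensor output (node.output[0] is the autograd context, excluded here).
    return [input_shapes[0]], [input_dtypes[0]]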
+    def _propagate_shape_and_type(self, node, input_index=0, output_index=0):
+        shape = self._get_shape(node, input_index)
+        output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[output_index]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape))
+
+    def _is_none_dim(self, dim_value):
+        if type(dim_value) != str:  # noqa: E721
+            return False
+        if "unk__" not in dim_value:
+            return False
+        if dim_value in self.symbolic_dims_:
+            return False
+        return True
+
+    def _is_shape_contains_none_dim(self, out_shape):
+        for out in out_shape:
+            if self._is_none_dim(out):
+                return out
+        return None
+
+    def _infer_impl(self, start_sympy_data=None):
+        self.sympy_data_ = start_sympy_data or {}
+        self.out_mp_.graph.ClearField("value_info")
+        self._apply_suggested_merge(graph_input_only=True)
+        self.input_symbols_ = set()
+        for i in self.out_mp_.graph.input:
+            input_shape = get_shape_from_value_info(i)
+            if input_shape is None:
+                continue
+
+            if is_sequence(i.type):
+                input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
+            else:
+                input_dims = i.type.tensor_type.shape.dim
+
+            for i_dim, dim in enumerate(input_shape):
+                if dim is None:
+                    # some models use None for a symbolic dim in inputs; replace it with a string
+                    input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim))
+
+            self.input_symbols_.update([d for d in input_shape if type(d) is str])
+
+        for s in self.input_symbols_:
+            if s in self.suggested_merge_:
+                s_merge = self.suggested_merge_[s]
+                assert s_merge in self.symbolic_dims_
+                self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
+            else:
+                # Since inputs are not produced by other ops, we can assume positivity
+                self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)
+        # create a temporary ModelProto for single-node inference;
+        # note that we remove initializers to speed up inference, since for tensor ops
+        # like Reshape/Tile/Expand that read initializers we need sympy-based computation anyway
+        self.tmp_mp_ = onnx.ModelProto()
+        self.tmp_mp_.CopyFrom(self.out_mp_)
+        self.tmp_mp_.graph.ClearField("initializer")
+
+        # compute prerequisites for each node, for topological sorting;
+        # a node with subgraphs may depend on implicit inputs, which affects the topological sort
+        prereq_for_node = {}  # map from node to all its inputs, including implicit ones in subgraphs
+
+        def get_prereq(node):
+            names = {i for i in node.input if i}
+            subgraphs = []
+            if node.op_type == "If":
+                subgraphs = [
+                    get_attribute(node, "then_branch"),
+                    get_attribute(node, "else_branch"),
+                ]
+            elif node.op_type in ["Loop", "Scan"]:
+                subgraphs = [get_attribute(node, "body")]
+            for g in subgraphs:
+                g_outputs_and_initializers = {i.name for i in g.initializer}
+                g_prereq = set()
+                for n in g.node:
+                    g_outputs_and_initializers.update(n.output)
+                for n in g.node:
+                    g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers])
+                names.update(g_prereq)
+                # remove subgraph inputs from names since those are local-only
+                for i in g.input:
+                    names.discard(i.name)
+            return names
+
+        for n in self.tmp_mp_.graph.node:
+            prereq_for_node[n.output[0]] = get_prereq(n)
+
+        # topologically sort the nodes; note there might be dead nodes, so we check whether
+        # all graph outputs have been reached to decide when to terminate
+        sorted_nodes = []
+        sorted_known_vi = {i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)}
+        if any(o.name in sorted_known_vi for o in self.out_mp_.graph.output):
+            # Loop/Scan will have some graph outputs in graph inputs, so don't do topological sort
+            sorted_nodes = self.out_mp_.graph.node
+        else:
+            while not all(o.name in sorted_known_vi for o in self.out_mp_.graph.output):
+                old_sorted_nodes_len = len(sorted_nodes)
+                for node in self.out_mp_.graph.node:
+                    if (node.output[0] not in sorted_known_vi) and all(
+                        i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i
+                    ):
+                        sorted_known_vi.update(node.output)
+                        sorted_nodes.append(node)
+                if old_sorted_nodes_len == len(sorted_nodes) and not all(
+                    o.name in sorted_known_vi for o in self.out_mp_.graph.output
+                ):
+                    raise Exception("Invalid model with cyclic graph")
+
+        for node in sorted_nodes:
+            assert all(i in self.known_vi_ for i in node.input if i)
+            self._onnx_infer_single_node(node)
+            known_aten_op = False
+            if node.op_type in self.dispatcher_:
+                self.dispatcher_[node.op_type](node)
+            elif node.op_type in ["ConvTranspose"]:
+                # onnx shape inference for ops like ConvTranspose may produce an empty shape for symbolic input
+                # until symbolic compute is added for them;
+                # mark the output type as UNDEFINED to allow guessing of the rank
+                vi = self.known_vi_[node.output[0]]
+                if len(vi.type.tensor_type.shape.dim) == 0:
+                    vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+            elif node.op_type == "ATen" and node.domain == "org.pytorch.aten":
+                for attr in node.attribute:
+                    # TODO: Is overload_name needed?
+                    if attr.name == "operator":
+                        aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
+                        if aten_op_name in self.aten_op_dispatcher_:
+                            known_aten_op = True
+                            self.aten_op_dispatcher_[aten_op_name](node)
+                        break
+
+            if self.verbose_ > 2:
+                logger.debug(node.op_type + ": " + node.name)  # noqa: G003
+                for i, name in enumerate(node.input):
+                    logger.debug(" Input %s: %s %s", i, name, "initializer" if name in self.initializers_ else "")
+
+            # onnx automatically merges dims with values, e.g. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'];
+            # symbolic shape inference needs to apply the merge 'aaa' -> 1000 in this case
+            if node.op_type in [
+                "Add",
+                "Sub",
+                "Mul",
+                "Div",
+                "MatMul",
+                "MatMulInteger",
+                "MatMulInteger16",
+                "Where",
+                "Sum",
+            ]:
+                vi = self.known_vi_[node.output[0]]
+                out_rank = len(get_shape_from_type_proto(vi.type))
+                in_shapes = [self._get_shape(node, i) for i in range(len(node.input))]
+                for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)):
+                    in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
+                    if len(in_dims) > 1:
+                        self._check_merged_dims(in_dims, allow_broadcast=True)
+
+            for i_o in range(len(node.output)):
+                # Special cases:
+                # 1) We do not care about the training-related outputs of SkipLayerNormalization
+                # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because
+                #    the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding
+                #    contrib op
+                if (
+                    node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization"
+                ) and i_o in [1, 2]:
+                    continue
+                if node.op_type == "RotaryEmbedding" and len(node.output) > 1:
+                    # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs
+                    # generated by `export_modules_as_functions`
+                    continue
+
+                vi = self.known_vi_[node.output[i_o]]
+                out_type = vi.type
+                out_type_kind = out_type.WhichOneof("value")
+
+                # do not process shape for non-tensors
+                if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]:
+                    if self.verbose_ > 2:
+                        if out_type_kind == "sequence_type":
+                            seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
+                            if seq_cls_type == "tensor_type":
+                                logger.debug(
+                                    " {}: sequence of {} {}".format(  # noqa: G001
+                                        node.output[i_o],
+                                        str(get_shape_from_value_info(vi)),
+                                        onnx.TensorProto.DataType.Name(
+                                            vi.type.sequence_type.elem_type.tensor_type.elem_type
+                                        ),
+                                    )
+                                )
+                            else:
+                                logger.debug(f" {node.output[i_o]}: sequence of {seq_cls_type}")
+                        else:
+                            logger.debug(f" {node.output[i_o]}: {out_type_kind}")
+                    continue
+
+                out_shape = get_shape_from_value_info(vi)
+                out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
+                if self.verbose_ > 2:
+                    logger.debug(
+                        f" {node.output[i_o]}: {out_shape!s} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}"
+                    )
+                    if node.output[i_o] in self.sympy_data_:
+                        logger.debug(" Sympy Data: " + str(self.sympy_data_[node.output[i_o]]))  # noqa: G003
+
+                # onnx >= 1.11.0 uses unk__#index instead of None when a shape dim is uncertain
+                if (
+                    out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape))
+                ) or out_type_undefined:
+                    if self.auto_merge_:
+                        if node.op_type in [
+                            "Add",
+                            "Sub",
+                            "Mul",
+                            "Div",
+                            "MatMul",
+                            "MatMulInteger",
+                            "MatMulInteger16",
+                            "Concat",
+                            "Where",
+                            "Sum",
+                            "Equal",
+                            "Less",
+                            "Greater",
+                            "LessOrEqual",
+                            "GreaterOrEqual",
+                            "Min",
+                            "Max",
+                        ]:
+                            shapes = [self._get_shape(node, i) for i in range(len(node.input))]
+                            if node.op_type in [
+                                "MatMul",
+                                "MatMulInteger",
+                                "MatMulInteger16",
+                            ]:
+                                if None in out_shape or self._is_shape_contains_none_dim(out_shape):
+                                    if None in out_shape:
+                                        idx = out_shape.index(None)
+                                    else:
+                                        idx = out_shape.index(self._is_shape_contains_none_dim(out_shape))
+                                    dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                    # only support auto merge for MatMul for dim < rank-2 when rank > 2
+                                    assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2
+                                    assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2
+                        elif node.op_type == "Expand":
+                            # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq])
+                            shapes = [
+                                self._get_shape(node, 0),
+                                self._get_value(node, 1),
+                            ]
+                        else:
+                            shapes = []
+
+                        if shapes:
+                            for idx in range(len(out_shape)):
+                                if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]):
+                                    continue
+                                # note that the broadcasting rule aligns from right to left;
+                                # if a tensor has a lower rank (its entry in dim_idx is negative),
+                                # it broadcasts automatically and needs no merge
+                                dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                if len(dim_idx) > 0:
+                                    self._add_suggested_merge(
+                                        [
+                                            s[i] if is_literal(s[i]) else str(s[i])
+                                            for s, i in zip(shapes, dim_idx, strict=False)
+                                            if i >= 0
+                                        ]
+                                    )
+                            self.run_ = True
+                        else:
+                            self.run_ = False
+                    else:
+                        self.run_ = False
+
+                    # create new dynamic dims for ops not handled by symbolic shape inference
+                    if self.run_ is False and node.op_type not in self.dispatcher_ and not known_aten_op:
+                        is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0)
+                        if is_unknown_op:
+                            # op unknown to ONNX, maybe from a higher opset or another domain;
+                            # only guess the output rank from input 0 when using the guess_output_rank option
+                            out_rank = self._get_shape_rank(node, 0) if self.guess_output_rank_ else -1
+                        else:
+                            # valid ONNX op, but not handled by symbolic shape inference; just assign a dynamic shape
+                            out_rank = len(out_shape)
+
+                        if out_rank >= 0:
+                            new_shape = self._new_symbolic_shape(out_rank, node, i_o)
+                            if out_type_undefined:
+                                # guess the output data type from input 0's value_info if not defined
+                                out_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+                            else:
+                                # otherwise, use the original data type
+                                out_dtype = vi.type.tensor_type.elem_type
+                            vi.CopyFrom(
+                                helper.make_tensor_value_info(
+                                    vi.name,
+                                    out_dtype,
+                                    get_shape_from_sympy_shape(new_shape),
+                                )
+                            )
+
+                            if self.verbose_ > 0:
+                                if is_unknown_op:
+                                    logger.debug(
+                                        f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape"
+                                    )
+                                if self.verbose_ > 2:
+                                    logger.debug(f" {node.output[i_o]}: {new_shape!s} {vi.type.tensor_type.elem_type}")
+
+                            self.run_ = True
+                            continue  # continue the inference after the guess; no need to stop, as no merge is needed
+
+                    if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
+                        logger.debug("Stopping at incomplete shape inference at %s: %s", node.op_type, node.name)
+                        logger.debug("node inputs:")
+                        for i in node.input:
+                            if i in self.known_vi_:
+                                logger.debug(self.known_vi_[i])
+                            else:
+                                logger.debug(f"not in known_vi_ for {i}")
+                        logger.debug("node outputs:")
+                        for o in node.output:
+                            if o in self.known_vi_:
+                                logger.debug(self.known_vi_[o])
+                            else:
+                                logger.debug(f"not in known_vi_ for {o}")
+                        if self.auto_merge_ and not out_type_undefined:
+                            logger.debug("Merging: " + str(self.suggested_merge_))  # noqa: G003
+                        return False
+
+        self.run_ = False
+        return True
+
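The topological sort near the top of _infer_impl is a readiness sweep: it repeatedly appends any node whose prerequisites are all known, and fails if a full pass adds nothing before the graph outputs are covered. A minimal standalone sketch of the same idea (toy node tuples, not the ONNX protos):

def toposort(nodes, known):
    # nodes: list of (name, prereqs); known: names already available (inputs/initializers)
    ordered, known = [], set(known)
    while len(ordered) < len(nodes):
        progressed = False
        for name, prereqs in nodes:
            if name not in known and all(p in known for p in prereqs):
                known.add(name)
                ordered.append(name)
                progressed = True
        if not progressed:
            raise Exception("Invalid model with cyclic graph")
    return ordered

assert toposort([("b", ["a"]), ("c", ["b"]), ("a", [])], set()) == ["a", "b", "c"]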
+    def _update_output_from_vi(self):
+        for output in self.out_mp_.graph.output:
+            if output.name in self.known_vi_:
+                output.CopyFrom(self.known_vi_[output.name])
+
+    @staticmethod
+    def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
+        onnx_opset = get_opset(in_mp)
+        if (not onnx_opset) or onnx_opset < 7:
+            logger.warning("Only models with ONNX opset 7 and above are supported.")
+            return None
+        symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose)
+        all_shapes_inferred = False
+        symbolic_shape_inference._preprocess(in_mp)
+        while symbolic_shape_inference.run_:
+            all_shapes_inferred = symbolic_shape_inference._infer_impl()
+        symbolic_shape_inference._update_output_from_vi()
+        if not all_shapes_inferred:
+            onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
+            raise Exception("Incomplete symbolic shape inference")
+        return symbolic_shape_inference.out_mp_
+
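Typical programmatic use of the entry point above, assuming a model at model.onnx (file names illustrative):

import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("model.onnx")
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True, verbose=0)
onnx.save(inferred, "model_with_shapes.onnx")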
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", required=True, help="The input model file")
+    parser.add_argument("--output", help="The output model file")
+    parser.add_argument(
+        "--auto_merge",
+        help="Automatically merge symbolic dims when a conflict happens",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--int_max",
+        help="Maximum value for an integer to be treated as boundless for ops like Slice",
+        type=int,
+        default=2**31 - 1,
+    )
+    parser.add_argument(
+        "--guess_output_rank",
+        help="Guess the output rank to be the same as input 0 for unknown ops",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--verbose",
+        help="Prints detailed logs of inference. 0: turned off, 1: warnings, 3: detailed",
+        type=int,
+        default=0,
+    )
+    parser.add_argument(
+        "--save_as_external_data",
+        help="Save the ONNX model with external data",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--all_tensors_to_one_file",
+        help="Save all the external data in one file",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--external_data_location",
+        help="The file location to save the external data to",
+        default="./",
+    )
+    parser.add_argument(
+        "--external_data_size_threshold",
+        help="The size threshold for external data",
+        type=int,
+        default=1024,
+    )
+    return parser.parse_args()
+
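Run as a script, these flags map directly onto infer_shapes (see the __main__ block below). A typical invocation, with illustrative file names, assuming the module is executed via onnxruntime.tools:

python -m onnxruntime.tools.symbolic_shape_infer --input model.onnx --output model_with_shapes.onnx --auto_merge --verbose 1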
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    logger.info("input model: " + args.input)  # noqa: G003
+    if args.output:
+        logger.info("output model: " + args.output)  # noqa: G003
+    logger.info("Doing symbolic shape inference...")
+    out_mp = SymbolicShapeInference.infer_shapes(
+        onnx.load(args.input),
+        args.int_max,
+        args.auto_merge,
+        args.guess_output_rank,
+        args.verbose,
+    )
+    if args.output and out_mp:
+        if args.save_as_external_data:
+            onnx.save_model(
+                out_mp,
+                args.output,
+                save_as_external_data=True,
+                all_tensors_to_one_file=args.all_tensors_to_one_file,
+                location=args.external_data_location,
+                size_threshold=args.external_data_size_threshold,
+                convert_attribute=False,
+            )
+        else:
+            onnx.save(out_mp, args.output)
+        logger.info("Done!")