onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,434 @@
1
+ import argparse
2
+ import os
3
+
4
+ import numpy
5
+ import psutil
6
+ from onnx import TensorProto
7
+
8
+ """
9
+ This profiler tool could run a transformer model and print out the kernel time spent on each Node of the model.
10
+ Example of profiling of longformer model:
11
+ python profiler.py --model longformer-base-4096_fp32.onnx --batch_size 1 --sequence_length 4096 --global_length 8 --samples 1000 --thread_num 8 --dummy_inputs longformer --use_gpu
12
+ Example of importing profile result file from onnxruntime_perf_test:
13
+ python profiler.py --input profile_2021-10-25_12-02-41.json
14
+ """
15
+
16
+
17
def parse_arguments(argv=None):
    """Build and run the command-line parser for the profiler tool.

    Args:
        argv (list[str], optional): argument vector to parse instead of sys.argv.

    Returns:
        argparse.Namespace: parsed arguments.
    """
    arg_parser = argparse.ArgumentParser()

    # Input source: either an existing profile json (--input) or a model to profile (--model).
    arg_parser.add_argument(
        "-i",
        "--input",
        required=False,
        type=str,
        help="Set the input file for reading the profile results",
    )
    arg_parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        help="onnx model path to run profiling. Required when --input is not specified.",
    )

    # Shape of the dummy inputs fed to the model.
    arg_parser.add_argument(
        "-b", "--batch_size", required=False, type=int, default=1, help="batch size of input"
    )
    arg_parser.add_argument(
        "-s", "--sequence_length", required=False, type=int, default=32, help="sequence length of input"
    )
    arg_parser.add_argument(
        "--past_sequence_length", required=False, type=int, default=1, help="past sequence length for gpt2"
    )
    arg_parser.add_argument(
        "--global_length", required=False, type=int, default=1, help="number of global tokens for longformer"
    )

    # Measurement controls.
    arg_parser.add_argument(
        "--samples",
        required=False,
        type=int,
        default=1000,
        help="number of samples to test. Set it large enough to reduce the variance of performance result.",
    )
    arg_parser.add_argument(
        "--threshold",
        required=False,
        type=float,
        default=0.01,
        help="Threshold of run time ratio among all nodes. Nodes with larger ratio will show in top expensive nodes.",
    )
    arg_parser.add_argument(
        "--thread_num", required=False, type=int, default=-1, help="number of threads to use"
    )

    # Optional explicit input-tensor names (BERT-style models).
    arg_parser.add_argument(
        "--input_ids_name", required=False, type=str, default=None, help="input name for input IDs, for bert"
    )
    arg_parser.add_argument(
        "--segment_ids_name", required=False, type=str, default=None, help="input name for segment IDs, for bert"
    )
    arg_parser.add_argument(
        "--input_mask_name",
        required=False,
        type=str,
        default=None,
        help="input name for attention mask, for bert",
    )

    arg_parser.add_argument(
        "--dummy_inputs",
        required=False,
        default="default",
        choices=["bert", "gpt2", "longformer", "default"],
        help="Type of model inputs. The default will create dummy inputs with ones.",
    )

    # Execution environment flags.
    arg_parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="use GPU")
    arg_parser.set_defaults(use_gpu=False)

    arg_parser.add_argument(
        "--provider", required=False, type=str, default="cuda", help="Execution provider to use"
    )

    arg_parser.add_argument(
        "--basic_optimization",
        required=False,
        action="store_true",
        help="Enable only basic graph optimizations. By default, all optimizations are enabled in OnnxRuntime",
    )
    arg_parser.set_defaults(basic_optimization=False)

    arg_parser.add_argument(
        "--kernel_time_only",
        required=False,
        action="store_true",
        help="Only include the kernel time and no fence time",
    )
    arg_parser.set_defaults(kernel_time_only=False)

    arg_parser.add_argument("-v", "--verbose", required=False, action="store_true")
    arg_parser.set_defaults(verbose=False)

    return arg_parser.parse_args(argv)
155
+
156
+
157
def run_profile(onnx_model_path, use_gpu, provider, basic_optimization, thread_num, all_inputs):
    """Run every input feed through an ORT session with profiling on.

    Args:
        onnx_model_path (str): path of the ONNX model to profile.
        use_gpu (bool): whether to run on GPU.
        provider (str): execution provider to use.
        basic_optimization (bool): if True, enable only basic graph optimizations.
        thread_num (int): number of intra-op threads.
        all_inputs (list[dict]): list of input feeds to run.

    Returns:
        str: path of the profile file produced by onnxruntime.
    """
    from benchmark_helper import create_onnxruntime_session  # noqa: PLC0415

    sess = create_onnxruntime_session(
        onnx_model_path,
        use_gpu,
        provider,
        enable_all_optimization=not basic_optimization,
        num_threads=thread_num,
        enable_profiling=True,
    )

    # Run all samples; outputs are discarded, only the profiling trace matters.
    for feed in all_inputs:
        sess.run(None, feed)

    return sess.end_profiling()
174
+
175
+
176
def get_dim_from_type_proto(dim):
    """Return the value of one dimension from a TypeProto dimension.

    Args:
        dim: a TensorShapeProto.Dimension protobuf message.

    Returns:
        int | str | None: the concrete size (dim_value), the symbolic name
        (dim_param), or None when the dimension is unset.
    """
    # WhichOneof returns the name of the populated field ("dim_value" or
    # "dim_param") or None when neither is set. Call it once and use
    # isinstance instead of the `type(...) == str` anti-idiom (flagged E721).
    which = dim.WhichOneof("value")
    return getattr(dim, which) if isinstance(which, str) else None


def get_shape_from_type_proto(type_proto):
    """Return the tensor shape of a TypeProto as a list.

    Each element is an int (fixed), a str (symbolic), or None (unknown).
    """
    return [get_dim_from_type_proto(d) for d in type_proto.tensor_type.shape.dim]
182
+
183
+
184
def create_dummy_inputs(onnx_model, batch_size, sequence_length, samples):
    """Create dummy inputs (all ones) for an ONNX model.

    Args:
        onnx_model (OnnxModel): ONNX model
        batch_size (int): batch size
        sequence_length (int): sequence length
        samples (int): number of samples

    Returns:
        List[Dict] | None: list of input feeds, or None when the model has
        more than two symbolic dimensions per input (unsupported here).
    """
    feeds = {}
    for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(graph_input.type)
        symbolic_positions = [i for i, d in enumerate(shape) if isinstance(d, str)]

        # Only two symbolic dimensions are supported: they are interpreted as
        # batch_size and sequence_length, in order of appearance.
        if len(symbolic_positions) > 2:
            return None
        for pos, value in zip(symbolic_positions, (batch_size, sequence_length)):
            shape[pos] = value

        elem_type = graph_input.type.tensor_type.elem_type
        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
        if elem_type == TensorProto.FLOAT:
            dtype = numpy.float32
        elif elem_type == TensorProto.INT64:
            dtype = numpy.int64
        else:
            dtype = numpy.int32
        feeds[graph_input.name] = numpy.ones(shape, dtype=dtype)

    # The same feed dict is reused for every sample; inputs are never mutated.
    return [feeds] * samples
224
+
225
+
226
def create_bert_inputs(
    onnx_model,
    batch_size,
    sequence_length,
    samples,
    input_ids_name=None,
    segment_ids_name=None,
    input_mask_name=None,
):
    """Create dummy inputs for a BERT model.

    Args:
        onnx_model (OnnxModel): ONNX model
        batch_size (int): batch size
        sequence_length (int): sequence length
        samples (int): number of samples
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Returns:
        List[Dict]: list of inputs
    """
    from bert_test_data import find_bert_inputs, generate_test_data  # noqa: PLC0415

    # Locate the three BERT graph inputs (by name when given, otherwise by heuristics).
    input_ids, segment_ids, input_mask = find_bert_inputs(
        onnx_model, input_ids_name, segment_ids_name, input_mask_name
    )

    return generate_test_data(
        batch_size,
        sequence_length,
        test_cases=samples,
        seed=123,  # fixed seed keeps runs comparable
        verbose=False,
        input_ids=input_ids,
        segment_ids=segment_ids,
        input_mask=input_mask,
        random_mask_length=False,
    )
265
+
266
+
267
def create_gpt2_inputs(onnx_model, batch_size, sequence_length, past_sequence_length, samples):
    """Create dummy inputs for a GPT-2 model.

    Args:
        onnx_model (OnnxModel): ONNX model
        batch_size (int): batch size
        sequence_length (int): sequence length
        past_sequence_length (int): past sequence length
        samples (int): number of samples

    Raises:
        RuntimeError: when the model uses a symbolic dimension outside the
            known set. Use the tool convert_to_onnx.py to export ONNX model instead.

    Returns:
        List[Dict]: list of inputs
    """
    # The symbolic names shall be same as those used in Gpt2Helper.export_onnx(...) function.
    symbols = {
        "batch_size": batch_size,
        "seq_len": sequence_length,
        "past_seq_len": past_sequence_length,
        "total_seq_len": sequence_length + past_sequence_length,
    }

    feeds = {}
    for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(graph_input.type)
        for i, dim in enumerate(shape):
            if not isinstance(dim, str):
                continue
            if dim not in symbols:
                raise RuntimeError(f"symbol is not supported: {dim}")
            shape[i] = symbols[dim]

        elem_type = graph_input.type.tensor_type.elem_type
        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
        if elem_type == TensorProto.FLOAT:
            dtype = numpy.float32
        elif elem_type == TensorProto.INT64:
            dtype = numpy.int64
        else:
            dtype = numpy.int32
        feeds[graph_input.name] = numpy.ones(shape, dtype=dtype)

    # The same feed dict is reused for every sample; inputs are never mutated.
    return [feeds] * samples
313
+
314
+
315
def create_longformer_inputs(onnx_model, batch_size, sequence_length, global_length, samples):
    """Create dummy inputs for a Longformer model.

    Args:
        onnx_model (OnnxModel): ONNX model
        batch_size (int): batch size
        sequence_length (int): sequence length
        global_length (int): number of global tokens
        samples (int): number of samples

    Raises:
        RuntimeError: when the model uses a symbolic dimension outside the
            known set. Use the tool convert_longformer_to_onnx.py to export
            ONNX model instead.

    Returns:
        List[Dict]: list of inputs
    """
    symbols = {"batch_size": batch_size, "sequence_length": sequence_length}

    feeds = {}
    for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(graph_input.type)
        for i, dim in enumerate(shape):
            if not isinstance(dim, str):
                continue
            if dim not in symbols:
                raise RuntimeError(f"symbol is not supported: {dim}")
            shape[i] = symbols[dim]

        elem_type = graph_input.type.tensor_type.elem_type
        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
        if elem_type == TensorProto.FLOAT:
            dtype = numpy.float32
        elif elem_type == TensorProto.INT64:
            dtype = numpy.int64
        else:
            dtype = numpy.int32

        # Inputs whose name contains "global" are the global-attention mask:
        # mark the first global_length tokens of each row, zero elsewhere.
        if "global" in graph_input.name:
            data = numpy.zeros(shape, dtype=dtype)
            data[:, :global_length] = 1
        else:
            data = numpy.ones(shape, dtype=dtype)
        feeds[graph_input.name] = data

    # The same feed dict is reused for every sample; inputs are never mutated.
    return [feeds] * samples
360
+
361
+
362
def run(args):
    """Profile the model described by the parsed arguments.

    Args:
        args (argparse.Namespace): arguments from parse_arguments.

    Returns:
        str: path of the profile file produced by onnxruntime.
    """
    num_threads = args.thread_num if args.thread_num > 0 else psutil.cpu_count(logical=False)

    # Set OMP environment variable before importing onnxruntime. Needed for cpu only,
    # and no impact for onnxruntime-gpu package.
    os.environ.setdefault("OMP_NUM_THREADS", str(num_threads))

    from onnx import load  # noqa: PLC0415
    from onnx_model import OnnxModel  # noqa: PLC0415

    onnx_model = OnnxModel(load(args.model))

    # Dispatch table mapping --dummy_inputs choice to its input generator.
    generators = {
        "bert": lambda: create_bert_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.samples,
            args.input_ids_name,
            args.segment_ids_name,
            args.input_mask_name,
        ),
        "gpt2": lambda: create_gpt2_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.past_sequence_length,
            args.samples,
        ),
        "longformer": lambda: create_longformer_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.global_length,
            args.samples,
        ),
    }
    default_generator = lambda: create_dummy_inputs(  # noqa: E731
        onnx_model, args.batch_size, args.sequence_length, args.samples
    )
    all_inputs = generators.get(args.dummy_inputs, default_generator)()

    return run_profile(
        args.model,
        args.use_gpu,
        args.provider,
        args.basic_optimization,
        args.thread_num,
        all_inputs,
    )
414
+
415
+
416
if __name__ == "__main__":
    arguments = parse_arguments()
    print("Arguments", arguments)

    from benchmark_helper import setup_logger

    setup_logger(arguments.verbose)

    # Either replay an existing profile file or run profiling on a model.
    if arguments.input:
        profile_file = arguments.input
    else:
        assert arguments.model, "requires either --model to run profiling or --input to read profiling results"
        profile_file = run(arguments)

    from profile_result_processor import process_results

    for line in process_results(profile_file, arguments):
        print(line)
@@ -0,0 +1,76 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import logging
8
+ import os
9
+
10
+ import onnx
11
+ import torch
12
+ from transformers.modeling_utils import Conv1D
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def _conv1d_to_linear(module):
18
+ in_size, out_size = module.weight.shape
19
+ linear = torch.nn.Linear(in_size, out_size)
20
+ linear.weight.data = module.weight.data.T.contiguous()
21
+ linear.bias.data = module.bias.data
22
+ return linear
23
+
24
+
25
def conv1d_to_linear(model):
    """Recursively replace every transformers Conv1D submodule with an
    equivalent nn.Linear, in place.

    Needed for dynamic quantization: PyTorch quantizes nn.Linear modules
    but does not recognize Conv1D.
    """
    logger.debug("replace Conv1D with Linear")
    for child_name in list(model._modules):
        child = model._modules[child_name]
        if isinstance(child, Conv1D):
            model._modules[child_name] = _conv1d_to_linear(child)
        else:
            conv1d_to_linear(child)
37
+
38
+
39
+ def _get_size_of_pytorch_model(model):
40
+ torch.save(model.state_dict(), "temp.p")
41
+ size = os.path.getsize("temp.p") / (1024 * 1024)
42
+ os.remove("temp.p")
43
+ return size
44
+
45
+
46
class QuantizeHelper:
    """Static helpers for dynamic (int8) quantization of Torch and ONNX models."""

    @staticmethod
    def quantize_torch_model(model, dtype=torch.qint8):
        """
        Usage: model = quantize_model(model)

        TODO: mix of in-place and return, but results are different
        """
        # Conv1D layers must first be converted so quantize_dynamic sees nn.Linear.
        conv1d_to_linear(model)
        quantized = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=dtype)
        logger.info(f"Size of full precision Torch model(MB):{_get_size_of_pytorch_model(model)}")
        logger.info(f"Size of quantized Torch model(MB):{_get_size_of_pytorch_model(quantized)}")
        return quantized

    @staticmethod
    def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data_format=False):
        """Dynamically quantize the ONNX model at *onnx_model_path*, writing the
        result to *quantized_model_path*, and log both file sizes."""
        from pathlib import Path  # noqa: PLC0415

        from onnxruntime.quantization import quantize_dynamic  # noqa: PLC0415

        Path(quantized_model_path).parent.mkdir(parents=True, exist_ok=True)
        mb = 1024 * 1024
        logger.info(f"Size of full precision ONNX model(MB):{os.path.getsize(onnx_model_path) / mb}")
        quantize_dynamic(
            onnx_model_path,
            quantized_model_path,
            use_external_data_format=use_external_data_format,
            extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
        )
        logger.info(f"quantized model saved to:{quantized_model_path}")
        # TODO: include external data in total model size.
        logger.info(f"Size of quantized ONNX model(MB):{os.path.getsize(quantized_model_path) / mb}")
@@ -0,0 +1,121 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import logging
7
+ import os
8
+ import sys
9
+
10
+ # In ORT Package the symbolic_shape_infer.py is in ../tools
11
+ file_path = os.path.dirname(__file__)
12
+ if os.path.exists(os.path.join(file_path, "../tools/symbolic_shape_infer.py")):
13
+ sys.path.append(os.path.join(file_path, "../tools"))
14
+ else:
15
+ sys.path.append(os.path.join(file_path, ".."))
16
+
17
+ from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy # noqa: E402
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class SymbolicShapeInferenceHelper(SymbolicShapeInference):
    """Symbolic shape inference that can bind named dynamic axes (like
    "batch_size") to concrete integer values supplied by the caller."""

    def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_output_rank=False):
        super().__init__(int_max, auto_merge, guess_output_rank, verbose)
        self.model_ = model
        # Whether the last infer() call resolved every shape in the model.
        self.all_shapes_inferred_: bool = False
        # Whether infer() has completed at least once (enables result caching).
        self.is_inferred_: bool = False
        # Maps a dynamic-axis name (e.g. "batch_size") to its concrete value.
        self.dynamic_axis_mapping_: dict[str, int] = {}

    def infer(self, dynamic_axis_mapping: dict[str, int], max_runs: int = 200):
        """Run shape inference, and try replace dynamic axis from string to integer when mapping is provided.

        Args:
            dynamic_axis_mapping (dict[str, int]): a dictionary with name of dynamic axis as key, like {"batch_size" : 4}
            max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 200.

        Returns:
            bool: whether all shapes have been inferred or not.
        """
        assert dynamic_axis_mapping is not None

        # Reuse the cached result when inference already ran with this mapping.
        if self.is_inferred_ and self.dynamic_axis_mapping_ == dynamic_axis_mapping:
            return self.all_shapes_inferred_

        self.dynamic_axis_mapping_ = dynamic_axis_mapping

        self._preprocess(self.model_)

        count = 0
        while self.run_:
            logger.debug("shape infer run %d", count)
            self.all_shapes_inferred_ = self._infer_impl()
            count += 1
            if max_runs > 0 and count >= max_runs:
                break

        self.is_inferred_ = True
        return self.all_shapes_inferred_

    def _get_sympy_shape(self, node, idx):
        """Override it to ensure shape inference by giving the actual value of dynamic axis."""
        sympy_shape = []

        shape = self._get_shape(node, idx)
        if shape:
            for dim in shape:
                if isinstance(dim, str):
                    # Prefer the caller-supplied concrete value for a named axis.
                    if dim in self.dynamic_axis_mapping_:
                        sympy_shape.append(self.dynamic_axis_mapping_[dim])
                    elif dim in self.symbolic_dims_:
                        sympy_shape.append(self.symbolic_dims_[dim])
                    else:
                        sympy_shape.append(sympy.Symbol(dim, integer=True))
                else:
                    assert dim is not None
                    sympy_shape.append(dim)
        return sympy_shape

    def get_edge_shape(self, edge):
        """Get shape of an edge.

        Args:
            edge (str): name of edge

        Returns:
            Optional[List[int]]: the shape, or None if shape is unknown
        """
        assert self.all_shapes_inferred_
        if edge not in self.known_vi_:
            # Use the module logger instead of print: this is library code.
            logger.warning("Cannot retrieve the shape of %s", edge)
            return None

        type_proto = self.known_vi_[edge].type
        shape = get_shape_from_type_proto(type_proto)

        if shape is not None:
            # Substitute concrete values for any mapped dynamic axis names.
            for i, dim in enumerate(shape):
                if isinstance(dim, str) and dim in self.dynamic_axis_mapping_:
                    shape[i] = self.dynamic_axis_mapping_[dim]

        return shape

    def compare_shape(self, edge, edge_other):
        """Compare shape of two edges.

        Args:
            edge (str): name of edge
            edge_other (str): name of another edge

        Raises:
            RuntimeError: At least one shape is missed for edges to compare

        Returns:
            bool: whether the shape is same or not
        """
        assert self.all_shapes_inferred_
        shape = self.get_edge_shape(edge)
        shape_other = self.get_edge_shape(edge_other)
        if shape is None or shape_other is None:
            # RuntimeError is a subclass of Exception, so existing callers
            # catching Exception still work.
            raise RuntimeError("At least one shape is missed for edges to compare")
        return shape == shape_other