onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: fbs

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()


class TensorTypeAndShape(object):
    """Read-side accessor for the fbs.TensorTypeAndShape flatbuffers table."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Resolve the root table position encoded at the front of the buffer.
        root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        obj = TensorTypeAndShape()
        obj.Init(buf, root + offset)
        return obj

    @classmethod
    def GetRootAsTensorTypeAndShape(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def TensorTypeAndShapeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is "ORTM", the ORT format file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # TensorTypeAndShape
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # TensorTypeAndShape
    def ElemType(self):
        # vtable slot 4: element data type stored as int32; 0 when absent.
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int32Flags, field + self._tab.Pos)

    # TensorTypeAndShape
    def Shape(self):
        # vtable slot 6: optional nested Shape table; None when absent.
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field == 0:
            return None
        table_pos = self._tab.Indirect(field + self._tab.Pos)
        from ort_flatbuffers_py.fbs.Shape import Shape
        shape_obj = Shape()
        shape_obj.Init(self._tab.Bytes, table_pos)
        return shape_obj


def TensorTypeAndShapeStart(builder):
    builder.StartObject(2)

def Start(builder):
    TensorTypeAndShapeStart(builder)

def TensorTypeAndShapeAddElemType(builder, elemType):
    builder.PrependInt32Slot(0, elemType, 0)

def AddElemType(builder, elemType):
    TensorTypeAndShapeAddElemType(builder, elemType)

def TensorTypeAndShapeAddShape(builder, shape):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)

def AddShape(builder, shape):
    TensorTypeAndShapeAddShape(builder, shape)

def TensorTypeAndShapeEnd(builder):
    return builder.EndObject()

def End(builder):
    return TensorTypeAndShapeEnd(builder)
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: fbs

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()


class TypeInfo(object):
    """Read-side accessor for the fbs.TypeInfo flatbuffers table."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Resolve the root table position encoded at the front of the buffer.
        root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        obj = TypeInfo()
        obj.Init(buf, root + offset)
        return obj

    @classmethod
    def GetRootAsTypeInfo(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def TypeInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is "ORTM", the ORT format file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # TypeInfo
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # TypeInfo
    def Denotation(self):
        # vtable slot 4: optional denotation string; None when absent.
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return None
        return self._tab.String(field + self._tab.Pos)

    # TypeInfo
    def ValueType(self):
        # vtable slot 6: union discriminant (uint8); 0 when absent.
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Uint8Flags, field + self._tab.Pos)

    # TypeInfo
    def Value(self):
        # vtable slot 8: union payload returned as an untyped Table; the caller
        # consults ValueType() to decide which concrete table to Init from it.
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if field == 0:
            return None
        from flatbuffers.table import Table
        union_tab = Table(bytearray(), 0)
        self._tab.Union(union_tab, field)
        return union_tab


def TypeInfoStart(builder):
    builder.StartObject(3)

def Start(builder):
    TypeInfoStart(builder)

def TypeInfoAddDenotation(builder, denotation):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(denotation), 0)

def AddDenotation(builder, denotation):
    TypeInfoAddDenotation(builder, denotation)

def TypeInfoAddValueType(builder, valueType):
    builder.PrependUint8Slot(1, valueType, 0)

def AddValueType(builder, valueType):
    TypeInfoAddValueType(builder, valueType)

def TypeInfoAddValue(builder, value):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(value), 0)

def AddValue(builder, value):
    TypeInfoAddValue(builder, value)

def TypeInfoEnd(builder):
    return builder.EndObject()

def End(builder):
    return TypeInfoEnd(builder)
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: fbs

class TypeInfoValue(object):
    """Discriminant values for the TypeInfo.value union (see TypeInfo.ValueType())."""
    NONE = 0
    tensor_type = 1
    sequence_type = 2
    map_type = 3
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: fbs

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()


class ValueInfo(object):
    """Read-side accessor for the fbs.ValueInfo flatbuffers table."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Resolve the root table position encoded at the front of the buffer.
        root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        obj = ValueInfo()
        obj.Init(buf, root + offset)
        return obj

    @classmethod
    def GetRootAsValueInfo(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ValueInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is "ORTM", the ORT format file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # ValueInfo
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # ValueInfo
    def Name(self):
        # vtable slot 4: optional name string; None when absent.
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return None
        return self._tab.String(field + self._tab.Pos)

    # ValueInfo
    def DocString(self):
        # vtable slot 6: optional doc string; None when absent.
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field == 0:
            return None
        return self._tab.String(field + self._tab.Pos)

    # ValueInfo
    def Type(self):
        # vtable slot 8: optional nested TypeInfo table; None when absent.
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if field == 0:
            return None
        table_pos = self._tab.Indirect(field + self._tab.Pos)
        from ort_flatbuffers_py.fbs.TypeInfo import TypeInfo
        type_obj = TypeInfo()
        type_obj.Init(self._tab.Bytes, table_pos)
        return type_obj


def ValueInfoStart(builder):
    builder.StartObject(3)

def Start(builder):
    ValueInfoStart(builder)

def ValueInfoAddName(builder, name):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)

def AddName(builder, name):
    ValueInfoAddName(builder, name)

def ValueInfoAddDocString(builder, docString):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)

def AddDocString(builder, docString):
    ValueInfoAddDocString(builder, docString)

def ValueInfoAddType(builder, type):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(type), 0)

def AddType(builder, type):
    ValueInfoAddType(builder, type)

def ValueInfoEnd(builder):
    return builder.EndObject()

def End(builder):
    return ValueInfoEnd(builder)
from os.path import dirname, basename, isfile, join, splitext
import glob

# Collect every sibling .py module (excluding this __init__) so that
# `from . import *` re-exports each of them as a package attribute.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    splitext(basename(module_path))[0]
    for module_path in modules
    if isfile(module_path) and not module_path.endswith('__init__.py')
]

from . import *
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import ort_flatbuffers_py.fbs as fbs

from .operator_type_usage_processors import OperatorTypeUsageManager


class OrtFormatModelProcessor:
    "Class to process an ORT format model and determine required operators and types."

    def __init__(self, model_path: str, required_ops: dict, processors: OperatorTypeUsageManager):
        """
        Initialize ORT format model processor
        :param model_path: Path to model to load
        :param required_ops: Dictionary required operator information will be added to.
        :param processors: Operator type usage processors which will be called for each matching Node.
        :raises RuntimeError: If the file does not carry the ORT format identifier.
        """
        self._required_ops = required_ops  # dictionary of {domain: {opset:[operators]}}
        # Read the whole model up front; use a context manager so the file handle
        # is closed deterministically (the previous open(...).read() leaked it until GC).
        with open(model_path, "rb") as model_file:
            self._file = model_file.read()
        self._buffer = bytearray(self._file)
        if not fbs.InferenceSession.InferenceSession.InferenceSessionBufferHasIdentifier(self._buffer, 0):
            raise RuntimeError(f"File does not appear to be a valid ORT format model: '{model_path}'")
        self._model = fbs.InferenceSession.InferenceSession.GetRootAsInferenceSession(self._buffer, 0).Model()
        self._op_type_processors = processors

    @staticmethod
    def _setup_type_info(graph: fbs.Graph, outer_scope_value_typeinfo=None):
        """
        Setup the node args for this level of Graph.
        We copy the current list which represents the outer scope values, and add the local node args to that
        to create the valid list of values for the current Graph.
        :param graph: Graph to create NodeArg list for
        :param outer_scope_value_typeinfo: TypeInfo for outer scope values. None/empty for the top-level
                                           graph in a model.
        :return: Dictionary of NodeArg name to TypeInfo
        """
        # None sentinel instead of a mutable {} default (flake8-bugbear B006).
        if outer_scope_value_typeinfo is None:
            value_name_to_typeinfo = {}
        else:
            value_name_to_typeinfo = outer_scope_value_typeinfo.copy()

        for j in range(graph.NodeArgsLength()):
            n = graph.NodeArgs(j)
            value_name_to_typeinfo[n.Name()] = n.Type()  # TypeInfo for this NodeArg's name

        return value_name_to_typeinfo

    def _add_required_op(self, domain: str, opset: int, op_type: str):
        """Record op_type under {domain: {opset: set_of_op_types}} in the shared dict."""
        if domain not in self._required_ops:
            self._required_ops[domain] = {opset: {op_type}}
        elif opset not in self._required_ops[domain]:
            self._required_ops[domain][opset] = {op_type}
        else:
            self._required_ops[domain][opset].add(op_type)

    def _process_graph(self, graph: fbs.Graph, outer_scope_value_typeinfo: dict):
        """
        Process one level of the Graph, descending into any subgraphs when they are found
        :param graph: Graph at the current nesting level.
        :param outer_scope_value_typeinfo: Outer scope NodeArg dictionary from ancestor graphs
        """
        # Merge the TypeInfo for all values in this level of the graph with the outer scope value TypeInfo.
        value_name_to_typeinfo = OrtFormatModelProcessor._setup_type_info(graph, outer_scope_value_typeinfo)

        for i in range(graph.NodesLength()):
            node = graph.Nodes(i)

            optype = node.OpType().decode()
            domain = node.Domain().decode() or "ai.onnx"  # empty domain defaults to ai.onnx

            self._add_required_op(domain, node.SinceVersion(), optype)

            if self._op_type_processors:
                self._op_type_processors.process_node(node, value_name_to_typeinfo)

            # Read all the attributes, recursing into any graph-valued ones.
            for j in range(node.AttributesLength()):
                attr = node.Attributes(j)
                attr_type = attr.Type()
                if attr_type == fbs.AttributeType.AttributeType.GRAPH:
                    self._process_graph(attr.G(), value_name_to_typeinfo)
                elif attr_type == fbs.AttributeType.AttributeType.GRAPHS:
                    # the ONNX spec doesn't currently define any operators that have multiple graphs in an attribute
                    # so entering this 'elif' isn't currently possible
                    for k in range(attr.GraphsLength()):
                        self._process_graph(attr.Graphs(k), value_name_to_typeinfo)

    def process(self):
        """Walk the model's main graph (and all subgraphs), populating required_ops and processors."""
        graph = self._model.Graph()
        outer_scope_value_typeinfo = {}  # no outer scope values for the main graph
        self._process_graph(graph, outer_scope_value_typeinfo)
@@ -0,0 +1,84 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ import ort_flatbuffers_py.fbs as fbs
5
+
6
+
7
class FbsTypeInfo:
    "Class to provide conversion between ORT flatbuffers schema values and C++ types"

    # Maps an fbs.TensorDataType enum value to the C++ type string used in ORT kernel
    # type registrations. Complex types are intentionally absent (not supported).
    tensordatatype_to_string = {  # noqa: RUF012
        fbs.TensorDataType.TensorDataType.FLOAT: "float",
        fbs.TensorDataType.TensorDataType.UINT8: "uint8_t",
        fbs.TensorDataType.TensorDataType.INT8: "int8_t",
        fbs.TensorDataType.TensorDataType.UINT16: "uint16_t",
        fbs.TensorDataType.TensorDataType.INT16: "int16_t",
        fbs.TensorDataType.TensorDataType.INT32: "int32_t",
        fbs.TensorDataType.TensorDataType.INT64: "int64_t",
        fbs.TensorDataType.TensorDataType.STRING: "std::string",
        fbs.TensorDataType.TensorDataType.BOOL: "bool",
        fbs.TensorDataType.TensorDataType.FLOAT16: "MLFloat16",
        fbs.TensorDataType.TensorDataType.DOUBLE: "double",
        fbs.TensorDataType.TensorDataType.UINT32: "uint32_t",
        fbs.TensorDataType.TensorDataType.UINT64: "uint64_t",
        # fbs.TensorDataType.TensorDataType.COMPLEX64: 'complex64 is not supported',
        # fbs.TensorDataType.TensorDataType.COMPLEX128: 'complex128 is not supported',
        fbs.TensorDataType.TensorDataType.BFLOAT16: "BFloat16",
        fbs.TensorDataType.TensorDataType.FLOAT8E4M3FN: "Float8E4M3FN",
        fbs.TensorDataType.TensorDataType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ",
        fbs.TensorDataType.TensorDataType.FLOAT8E5M2: "Float8E5M2",
        fbs.TensorDataType.TensorDataType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ",
    }

    @staticmethod
    def typeinfo_to_str(type: fbs.TypeInfo):
        """
        Convert an fbs.TypeInfo to the string of the C++ type it represents.
        :param type: TypeInfo for a tensor, map or sequence value.
        :return: C++ type string, e.g. 'float' or 'std::map<int64_t,float>'.
        :raises ValueError: If the TypeInfo value type is unknown or unset.
        """
        value_type = type.ValueType()
        value = type.Value()
        type_str = "unknown"

        if value_type == fbs.TypeInfoValue.TypeInfoValue.tensor_type:
            tensor_type_and_shape = fbs.TensorTypeAndShape.TensorTypeAndShape()
            tensor_type_and_shape.Init(value.Bytes, value.Pos)
            elem_type = tensor_type_and_shape.ElemType()
            type_str = FbsTypeInfo.tensordatatype_to_string[elem_type]

        elif value_type == fbs.TypeInfoValue.TypeInfoValue.map_type:
            map_type = fbs.MapType.MapType()
            # BUG FIX: was `map_type.init(...)` (lowercase). flatbuffers-generated table
            # classes expose `Init` (as used by the tensor and sequence branches here),
            # so the lowercase call raised AttributeError whenever a map type was seen.
            map_type.Init(value.Bytes, value.Pos)
            key_type = map_type.KeyType()  # TensorDataType
            key_type_str = FbsTypeInfo.tensordatatype_to_string[key_type]
            value_type = map_type.ValueType()  # TypeInfo
            value_type_str = FbsTypeInfo.typeinfo_to_str(value_type)
            type_str = f"std::map<{key_type_str},{value_type_str}>"

        elif value_type == fbs.TypeInfoValue.TypeInfoValue.sequence_type:
            sequence_type = fbs.SequenceType.SequenceType()
            sequence_type.Init(value.Bytes, value.Pos)
            elem_type = sequence_type.ElemType()  # TypeInfo
            elem_type_str = FbsTypeInfo.typeinfo_to_str(elem_type)
            # TODO: Decide if we need to wrap the type in a std::vector. Issue is that the element type is internal
            # to the onnxruntime::Tensor class so we're really returning the type inside the Tensor not vector<Tensor>.
            # For now, return the element type (which will be the Tensor element type, or a map<A,B>) as
            # an operator input or output will either be a sequence or a not, so we don't need to disambiguate
            # between the two (i.e. we know if the returned value refers to the contents of a sequence, and can
            # handle whether it's the element type of a Tensor in the sequence, or the map type in a sequence of maps
            # due to this).
            type_str = elem_type_str
        else:
            raise ValueError(f"Unknown or missing value type of {value_type}")

        return type_str
70
+
71
+
72
def get_typeinfo(name: str, value_name_to_typeinfo: dict) -> fbs.TypeInfo:
    "Lookup a name in a dictionary mapping value name to TypeInfo."
    if name in value_name_to_typeinfo:
        return value_name_to_typeinfo[name]  # TypeInfo object

    raise RuntimeError("Missing TypeInfo entry for " + name)
78
+
79
+
80
def value_name_to_typestr(name: str, value_name_to_typeinfo: dict):
    "Lookup TypeInfo for value name and convert to a string representing the C++ type."
    # Single expression: resolve the TypeInfo, then render it as a C++ type string.
    return FbsTypeInfo.typeinfo_to_str(get_typeinfo(name, value_name_to_typeinfo))
@@ -0,0 +1,62 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ import pathlib
5
+ import typing
6
+
7
+ from ..logger import get_logger
8
+ from .operator_type_usage_processors import OperatorTypeUsageManager
9
+ from .ort_model_processor import OrtFormatModelProcessor
10
+
11
+ log = get_logger("ort_format_model.utils")
12
+
13
+
14
def _extract_ops_and_types_from_ort_models(model_files: typing.Iterable[pathlib.Path], enable_type_reduction: bool):
    """Collect the required operators (and, optionally, per-op type usage) from ORT format models."""
    required_ops = {}
    # Only track per-operator type usage when type reduction was requested.
    type_usage_manager = OperatorTypeUsageManager() if enable_type_reduction else None

    for model_path in model_files:
        if not model_path.is_file():
            raise ValueError(f"Path is not a file: '{model_path}'")

        # Processing updates required_ops and type_usage_manager in place.
        OrtFormatModelProcessor(str(model_path), required_ops, type_usage_manager).process()

    return required_ops, type_usage_manager
25
+
26
+
27
def create_config_from_models(
    model_files: typing.Iterable[pathlib.Path], output_file: pathlib.Path, enable_type_reduction: bool
):
    """
    Create a configuration file with required operators and optionally required types.
    :param model_files: Model files to use to generate the configuration file.
    :param output_file: File to write configuration to.
    :param enable_type_reduction: Include required type information for individual operators in the configuration.
    """
    # BUG FIX: model_files is typed as an Iterable but is iterated twice (once to extract
    # the ops, once to write the header comment). A generator input would be exhausted by
    # the first pass and silently produce an empty header — materialize it up front.
    model_files = list(model_files)

    required_ops, op_type_processors = _extract_ops_and_types_from_ort_models(model_files, enable_type_reduction)

    output_file.parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, "w") as out:
        out.write("# Generated from model/s:\n")
        for model_file in sorted(model_files):
            out.write(f"# - {model_file}\n")

        for domain in sorted(required_ops.keys()):
            for opset in sorted(required_ops[domain].keys()):
                ops = required_ops[domain][opset]
                if not ops:
                    # nothing required at this opset; skip the line entirely
                    continue

                out.write(f"{domain};{opset};")
                if enable_type_reduction:
                    # type string is empty if op hasn't been seen
                    entries = [
                        "{}{}".format(op, op_type_processors.get_config_entry(domain, op) or "")
                        for op in sorted(ops)
                    ]
                else:
                    entries = sorted(ops)

                out.write("{}\n".format(",".join(entries)))

    log.info("Created config in %s", output_file)
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """
5
+ Support for registering ONNX Runtime's built-in contrib ops with
6
+ PyTorch-ONNX exporter (torch.onnx.export).
7
+ """
8
+ import typing
9
+
10
+ try:
11
+ # TODO(justinchuby): Create a function to alert users when torch is not installed
12
+ import torch
13
+ except ModuleNotFoundError:
14
+ raise ModuleNotFoundError( # noqa: B904
15
+ "This module is only useful in combination with PyTorch. To install PyTorch see https://pytorch.org/."
16
+ )
17
+
18
+ from torch.onnx import symbolic_helper
19
+
20
+ _OPSET_VERSION = 1
21
+ _registered_ops: typing.AbstractSet[str] = set()
22
+
23
+
24
def _reg(symbolic_fn: typing.Callable):
    """Register symbolic_fn with the exporter and remember its name for unregister()."""
    # The leading '::' registers the op under the default (empty) namespace; the op
    # name is derived from the function's own name, so _reg'd functions must be
    # named exactly after the ATen op they implement.
    qualified_name = f"::{symbolic_fn.__name__}"
    torch.onnx.register_custom_op_symbolic(qualified_name, symbolic_fn, _OPSET_VERSION)
    _registered_ops.add(qualified_name)
+
29
+
30
def register():
    """Register ONNX Runtime's built-in contrib ops.

    Should be run before torch.onnx.export().
    """

    def grid_sampler(g, input, grid, mode, padding_mode, align_corners):
        # mode
        #   'bilinear'   : onnx::Constant[value={0}]
        #   'nearest'    : onnx::Constant[value={1}]
        #   'bicubic'    : onnx::Constant[value={2}]
        # padding_mode
        #   'zeros'      : onnx::Constant[value={0}]
        #   'border'     : onnx::Constant[value={1}]
        #   'reflection' : onnx::Constant[value={2}]
        mode = symbolic_helper._maybe_get_const(mode, "i")
        padding_mode = symbolic_helper._maybe_get_const(padding_mode, "i")
        # Map the integer constants above back to the string attributes GridSample expects.
        mode_str = ["bilinear", "nearest", "bicubic"][mode]
        padding_mode_str = ["zeros", "border", "reflection"][padding_mode]
        # align_corners arrives as a bool constant; the contrib op takes an int attribute.
        align_corners = int(symbolic_helper._maybe_get_const(align_corners, "b"))

        # From opset v13 onward, the output shape can be specified with
        # (N, C, H, W) (N, H_out, W_out, 2) => (N, C, H_out, W_out)
        # input_shape = input.type().sizes()
        # gird_shape = grid.type().sizes()
        # output_shape = input_shape[:2] + gird_shape[1:3]
        # g.op(...).setType(input.type().with_sizes(output_shape))

        return g.op(
            "com.microsoft::GridSample",
            input,
            grid,
            mode_s=mode_str,
            padding_mode_s=padding_mode_str,
            align_corners_i=align_corners,
        )

    _reg(grid_sampler)

    def inverse(g, self):
        # Output is declared to have the same type/shape as the input.
        return g.op("com.microsoft::Inverse", self).setType(self.type())

    _reg(inverse)

    @torch.onnx.symbolic_helper.parse_args("v", "s")
    def gelu(g, self: torch._C.Value, approximate: str = "none"):
        # Use microsoft::Gelu for performance if possible. It only supports approximate == "none"
        if approximate == "none":
            return g.op("com.microsoft::Gelu", self).setType(self.type())
        # Otherwise fall back to the standard opset 9 symbolic decomposition.
        return torch.onnx.symbolic_opset9.gelu(g, self, approximate)

    _reg(gelu)

    def triu(g, self, diagonal):
        # upper_i=1 selects the upper-triangular variant of the Trilu contrib op.
        return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1).setType(self.type())

    _reg(triu)

    def tril(g, self, diagonal):
        # upper_i=0 selects the lower-triangular variant of the Trilu contrib op.
        return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0).setType(self.type())

    _reg(tril)
+
93
+
94
def unregister():
    """Unregister ONNX Runtime's built-in contrib ops.

    Undoes register(): every name recorded in _registered_ops is removed from
    the exporter's custom symbolic registry.
    """
    for name in _registered_ops:
        try:
            torch.onnx.unregister_custom_op_symbolic(name, _OPSET_VERSION)
        except AttributeError:
            # The symbolic_registry module was removed in PyTorch 1.13.
            # We are importing it here for backwards compatibility
            # because unregister_custom_op_symbolic is not available before PyTorch 1.12
            from torch.onnx import symbolic_registry

            # NOTE(review): this fallback reaches into torch's private registry and
            # removes the op for every stable opset >= _OPSET_VERSION, mirroring what
            # register_custom_op_symbolic did on those older torch versions.
            namespace, kind = name.split("::")
            for version in symbolic_helper._onnx_stable_opsets:
                if version >= _OPSET_VERSION and symbolic_registry.is_registered_op(kind, namespace, version):
                    del symbolic_registry._registry[(namespace, version)][kind]