onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/__init__.py
@@ -0,0 +1,418 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ """
+ ONNX Runtime is a performance-focused scoring engine for Open Neural Network Exchange (ONNX) models.
+ For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://aka.ms/onnxruntime/>`_
+ or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
+ """
+
+ import contextlib
+
+ __version__ = "1.24.1"
+ __author__ = "Microsoft"
+
+ # We need to do device version validation (for example, to check the CUDA version for an onnxruntime-training package).
+ # In order to know whether the onnxruntime package is for training, it needs
+ # to import onnxruntime.training.ortmodule first.
+ # onnxruntime.capi._pybind_state is required before importing onnxruntime.training.ortmodule.
+ # However, importing onnxruntime.capi._pybind_state will already raise an exception if a required CUDA version
+ # is not found.
+ # Here we save the exception and continue with CUDA version validation in order to post
+ # meaningful messages to the user.
+ # The saved exception is raised after device version validation.
+ try:
+     from onnxruntime.capi._pybind_state import (
+         ExecutionMode,  # noqa: F401
+         ExecutionOrder,  # noqa: F401
+         GraphOptimizationLevel,  # noqa: F401
+         LoraAdapter,  # noqa: F401
+         ModelMetadata,  # noqa: F401
+         NodeArg,  # noqa: F401
+         OrtAllocatorType,  # noqa: F401
+         OrtArenaCfg,  # noqa: F401
+         OrtCompileApiFlags,  # noqa: F401
+         OrtDeviceMemoryType,  # noqa: F401
+         OrtEpAssignedNode,  # noqa: F401
+         OrtEpAssignedSubgraph,  # noqa: F401
+         OrtEpDevice,  # noqa: F401
+         OrtExecutionProviderDevicePolicy,  # noqa: F401
+         OrtExternalInitializerInfo,  # noqa: F401
+         OrtHardwareDevice,  # noqa: F401
+         OrtHardwareDeviceType,  # noqa: F401
+         OrtMemoryInfo,  # noqa: F401
+         OrtMemoryInfoDeviceType,  # noqa: F401
+         OrtMemType,  # noqa: F401
+         OrtSparseFormat,  # noqa: F401
+         OrtSyncStream,  # noqa: F401
+         RunOptions,  # noqa: F401
+         SessionIOBinding,  # noqa: F401
+         SessionOptions,  # noqa: F401
+         create_and_register_allocator,  # noqa: F401
+         create_and_register_allocator_v2,  # noqa: F401
+         disable_telemetry_events,  # noqa: F401
+         enable_telemetry_events,  # noqa: F401
+         get_all_providers,  # noqa: F401
+         get_available_providers,  # noqa: F401
+         get_build_info,  # noqa: F401
+         get_device,  # noqa: F401
+         get_ep_devices,  # noqa: F401
+         get_version_string,  # noqa: F401
+         has_collective_ops,  # noqa: F401
+         register_execution_provider_library,  # noqa: F401
+         set_default_logger_severity,  # noqa: F401
+         set_default_logger_verbosity,  # noqa: F401
+         set_global_thread_pool_sizes,  # noqa: F401
+         set_seed,  # noqa: F401
+         unregister_execution_provider_library,  # noqa: F401
+     )
+
+     import_capi_exception = None
+ except Exception as e:
+     import_capi_exception = e
+
+ from onnxruntime.capi import onnxruntime_validation
+
+ if import_capi_exception:
+     raise import_capi_exception
+
+ from onnxruntime.capi.onnxruntime_inference_collection import (
+     AdapterFormat,  # noqa: F401
+     InferenceSession,  # noqa: F401
+     IOBinding,  # noqa: F401
+     ModelCompiler,  # noqa: F401
+     OrtDevice,  # noqa: F401
+     OrtValue,  # noqa: F401
+     SparseTensor,  # noqa: F401
+     copy_tensors,  # noqa: F401
+ )
+
+ # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
+ try:  # noqa: SIM105
+     from . import experimental  # noqa: F401
+ except ImportError:
+     pass
+
+
+ package_name, version, cuda_version = onnxruntime_validation.get_package_name_and_version_info()
+
+ if version:
+     __version__ = version
+
+ onnxruntime_validation.check_distro_info()
+
+
+ def _get_package_version(package_name: str):
+     from importlib.metadata import PackageNotFoundError, version  # noqa: PLC0415
+
+     try:
+         package_version = version(package_name)
+     except PackageNotFoundError:
+         package_version = None
+     return package_version
+
+
+ def _get_package_root(package_name: str, directory_name: str | None = None):
+     from importlib.metadata import PackageNotFoundError, distribution  # noqa: PLC0415
+
+     root_directory_name = directory_name or package_name
+     try:
+         dist = distribution(package_name)
+         files = dist.files or []
+
+         for file in files:
+             if file.name.endswith("__init__.py") and root_directory_name in file.parts:
+                 return file.locate().parent
+
+         # Fallback to the first __init__.py
+         if not directory_name:
+             for file in files:
+                 if file.name.endswith("__init__.py"):
+                     return file.locate().parent
+     except PackageNotFoundError:
+         # package not found, do nothing
+         pass
+
+     return None
+
+
+ def _extract_cuda_major_version(version_str: str) -> str:
+     """Extract CUDA major version from version string (e.g., '12.1' -> '12').
+
+     Args:
+         version_str: CUDA version string to parse
+
+     Returns:
+         Major version as string, or "12" if parsing fails
+     """
+     return version_str.split(".")[0] if version_str else "12"
+
+
+ def _get_cufft_version(cuda_major: str) -> str:
+     """Get cufft library version based on CUDA major version.
+
+     Args:
+         cuda_major: CUDA major version as string (e.g., "12", "13")
+
+     Returns:
+         cufft version as string
+     """
+     # cufft versions: CUDA 12.x -> 11, CUDA 13.x -> 12
+     return "12" if cuda_major == "13" else "11"
+
+
+ def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = True):
+     # Dynamically determine CUDA major version from build info
+     cuda_major_version = _extract_cuda_major_version(cuda_version)
+     cufft_version = _get_cufft_version(cuda_major_version)
+
+     if is_windows:
+         # Path is relative to site-packages directory.
+         cuda_dll_paths = [
+             ("nvidia", "cublas", "bin", f"cublasLt64_{cuda_major_version}.dll"),
+             ("nvidia", "cublas", "bin", f"cublas64_{cuda_major_version}.dll"),
+             ("nvidia", "cufft", "bin", f"cufft64_{cufft_version}.dll"),
+             ("nvidia", "cuda_runtime", "bin", f"cudart64_{cuda_major_version}.dll"),
+         ]
+         cudnn_dll_paths = [
+             ("nvidia", "cudnn", "bin", "cudnn_engines_runtime_compiled64_9.dll"),
+             ("nvidia", "cudnn", "bin", "cudnn_engines_precompiled64_9.dll"),
+             ("nvidia", "cudnn", "bin", "cudnn_heuristic64_9.dll"),
+             ("nvidia", "cudnn", "bin", "cudnn_ops64_9.dll"),
+             ("nvidia", "cudnn", "bin", "cudnn_adv64_9.dll"),
+             ("nvidia", "cudnn", "bin", "cudnn_graph64_9.dll"),
+             ("nvidia", "cudnn", "bin", "cudnn64_9.dll"),
+         ]
+     else:  # Linux
+         # cublas64 depends on cublasLt64, so cublasLt64 should be loaded first.
+         cuda_dll_paths = [
+             ("nvidia", "cublas", "lib", f"libcublasLt.so.{cuda_major_version}"),
+             ("nvidia", "cublas", "lib", f"libcublas.so.{cuda_major_version}"),
+             ("nvidia", "cuda_nvrtc", "lib", f"libnvrtc.so.{cuda_major_version}"),
+             ("nvidia", "curand", "lib", "libcurand.so.10"),
+             ("nvidia", "cufft", "lib", f"libcufft.so.{cufft_version}"),
+             ("nvidia", "cuda_runtime", "lib", f"libcudart.so.{cuda_major_version}"),
+         ]
+
+         # Do not load cudnn sub DLLs (they will be dynamically loaded later) to be consistent with PyTorch in Linux.
+         cudnn_dll_paths = [
+             ("nvidia", "cudnn", "lib", "libcudnn.so.9"),
+         ]
+
+     return (cuda_dll_paths if cuda else []) + (cudnn_dll_paths if cudnn else [])
+
+
+ def print_debug_info():
+     """Print information to help debugging."""
+     import importlib.util  # noqa: PLC0415
+     import os  # noqa: PLC0415
+     import platform  # noqa: PLC0415
+     from importlib.metadata import distributions  # noqa: PLC0415
+
+     print(f"{package_name} version: {__version__}")
+     if cuda_version:
+         print(f"CUDA version used in build: {cuda_version}")
+     print("platform:", platform.platform())
+
+     print("\nPython package, version and location:")
+     ort_packages = []
+     for dist in distributions():
+         package = dist.metadata["Name"]
+         if package == "onnxruntime" or package.startswith(("onnxruntime-", "ort-")):
+             # Exclude packages whose root directory name is not onnxruntime.
+             location = _get_package_root(package, "onnxruntime")
+             if location and (package not in ort_packages):
+                 ort_packages.append(package)
+                 print(f"{package}=={dist.version} at {location}")
+
+     if len(ort_packages) > 1:
+         print(
+             "\033[33mWARNING: multiple onnxruntime packages are installed to the same location. "
+             "Please `pip uninstall` all above packages, then `pip install` only one of them.\033[0m"
+         )
+
+     if cuda_version:
+         # Print versions of installed packages that are related to CUDA or cuDNN DLLs.
+         cuda_major = _extract_cuda_major_version(cuda_version)
+
+         packages = [
+             "torch",
+             f"nvidia-cuda-runtime-cu{cuda_major}",
+             f"nvidia-cudnn-cu{cuda_major}",
+             f"nvidia-cublas-cu{cuda_major}",
+             f"nvidia-cufft-cu{cuda_major}",
+             f"nvidia-curand-cu{cuda_major}",
+             f"nvidia-cuda-nvrtc-cu{cuda_major}",
+             f"nvidia-nvjitlink-cu{cuda_major}",
+         ]
+         for package in packages:
+             directory_name = "nvidia" if package.startswith("nvidia-") else None
+             version = _get_package_version(package)
+             if version:
+                 print(f"{package}=={version} at {_get_package_root(package, directory_name)}")
+             else:
+                 print(f"{package} not installed")
+
+     if platform.system() == "Windows":
+         print(f"\nEnvironment variable:\nPATH={os.environ.get('PATH', '(unset)')}")
+     elif platform.system() == "Linux":
+         print(f"\nEnvironment variable:\nLD_LIBRARY_PATH={os.environ.get('LD_LIBRARY_PATH', '(unset)')}")
+
+     if importlib.util.find_spec("psutil"):
+
+         def is_target_dll(path: str):
+             target_keywords = ["vcruntime140", "msvcp140"]
+             if cuda_version:
+                 target_keywords = ["cufft", "cublas", "cudart", "nvrtc", "curand", "cudnn", *target_keywords]
+             return any(keyword in path for keyword in target_keywords)
+
+         import psutil  # noqa: PLC0415
+
+         p = psutil.Process(os.getpid())
+
+         print("\nList of loaded DLLs:")
+         for lib in p.memory_maps():
+             if is_target_dll(lib.path.lower()):
+                 print(lib.path)
+
+         if cuda_version:
+             if importlib.util.find_spec("cpuinfo") and importlib.util.find_spec("py3nvml"):
+                 from .transformers.machine_info import get_device_info  # noqa: PLC0415
+
+                 print("\nDevice information:")
+                 print(get_device_info())
+             else:
+                 print("please `pip install py-cpuinfo py3nvml` to show device information.")
+     else:
+         print("please `pip install psutil` to show loaded DLLs.")
+
+
+ def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, directory=None):
+     """Preload CUDA 12.x+ and cuDNN 9.x DLLs in Windows or Linux, and MSVC runtime DLLs in Windows.
+
+     When the installed PyTorch is compatible (built with the same major versions of CUDA and cuDNN),
+     there is no need to call this function if `import torch` is done before `import onnxruntime`.
+
+     Args:
+         cuda (bool, optional): enable loading CUDA DLLs. Defaults to True.
+         cudnn (bool, optional): enable loading cuDNN DLLs. Defaults to True.
+         msvc (bool, optional): enable loading MSVC DLLs in Windows. Defaults to True.
+         directory (str, optional): a directory that contains the CUDA or cuDNN DLLs. It can be an absolute path,
+             or a path relative to the directory of this file.
+             If directory is None (the default), the search order is: the lib directory of a compatible PyTorch in Windows,
+             the nvidia site packages, then the default DLL loading paths.
+             If directory is an empty string (""), the search order is: the nvidia site packages, then the default DLL loading paths.
+             If directory is a path, the search order is: that directory, then the default DLL loading paths.
+     """
+     import ctypes  # noqa: PLC0415
+     import os  # noqa: PLC0415
+     import platform  # noqa: PLC0415
+     import sys  # noqa: PLC0415
+
+     if platform.system() not in ["Windows", "Linux"]:
+         return
+
+     is_windows = platform.system() == "Windows"
+     if is_windows and msvc:
+         try:
+             ctypes.CDLL("vcruntime140.dll")
+             ctypes.CDLL("msvcp140.dll")
+             if platform.machine() != "ARM64":
+                 ctypes.CDLL("vcruntime140_1.dll")
+         except OSError:
+             print("Microsoft Visual C++ Redistributable is not installed, which may lead to DLL load failures.")
+             print("It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe.")
+
+     # Check if CUDA version is supported (12.x or 13.x+)
+     ort_cuda_major = None
+     if cuda_version:
+         try:
+             ort_cuda_major = int(cuda_version.split(".")[0])
+             if ort_cuda_major < 12 and (cuda or cudnn):
+                 print(
+                     f"\033[33mWARNING: {package_name} is built with CUDA {cuda_version}, which is not supported for preloading. "
+                     f"CUDA 12.x or newer is required. Call preload_dlls with cuda=False and cudnn=False.\033[0m"
+                 )
+                 return
+         except ValueError:
+             print(
+                 f"\033[33mWARNING: Unable to parse CUDA version '{cuda_version}'. "
+                 "Skipping DLL preloading. Call preload_dlls with cuda=False and cudnn=False.\033[0m"
+             )
+             return
+     elif cuda or cudnn:
+         # No CUDA version info available but CUDA/cuDNN preloading requested
+         return
+
+     is_cuda_cudnn_imported_by_torch = False
+
+     if is_windows:
+         torch_version = _get_package_version("torch")
+         # Check if torch CUDA version matches onnxruntime CUDA version
+         torch_cuda_major = None
+         if torch_version and "+cu" in torch_version:
+             with contextlib.suppress(ValueError):
+                 # Extract CUDA version from torch (e.g., "2.0.0+cu121" -> 12)
+                 cu_part = torch_version.split("+cu")[1]
+                 torch_cuda_major = int(cu_part[:2])  # First 2 digits are major version
+
+         is_torch_cuda_compatible = (
+             torch_cuda_major == ort_cuda_major if (torch_cuda_major and ort_cuda_major) else False
+         )
+
+         if "torch" in sys.modules:
+             is_cuda_cudnn_imported_by_torch = is_torch_cuda_compatible
+             if torch_cuda_major and ort_cuda_major and torch_cuda_major != ort_cuda_major:
+                 print(
+                     f"\033[33mWARNING: The installed PyTorch {torch_version} uses CUDA {torch_cuda_major}.x, "
+                     f"but {package_name} is built with CUDA {ort_cuda_major}.x. "
+                     f"Please install PyTorch for CUDA {ort_cuda_major}.x to be compatible.\033[0m"
+                 )
+
+         if is_torch_cuda_compatible and directory is None:
+             torch_root = _get_package_root("torch", "torch")
+             if torch_root:
+                 directory = os.path.join(torch_root, "lib")
+
+     base_directory = directory or ".."
+     if not os.path.isabs(base_directory):
+         base_directory = os.path.join(os.path.dirname(__file__), base_directory)
+     base_directory = os.path.normpath(base_directory)
+     if not os.path.isdir(base_directory):
+         raise RuntimeError(f"Invalid parameter of directory={directory}. The directory does not exist!")
+
+     if is_cuda_cudnn_imported_by_torch:
+         # In Windows, PyTorch has loaded CUDA and cuDNN DLLs during `import torch`, no need to load them again.
+         print("Skip loading CUDA and cuDNN DLLs since torch is imported.")
+         return
+
+     # Try load DLLs from nvidia site packages.
+     dll_paths = _get_nvidia_dll_paths(is_windows, cuda, cudnn)
+     loaded_dlls = []
+     for relative_path in dll_paths:
+         dll_path = (
+             os.path.join(base_directory, relative_path[-1])
+             if directory
+             else os.path.join(base_directory, *relative_path)
+         )
+         if os.path.isfile(dll_path):
+             try:
+                 _ = ctypes.CDLL(dll_path)
+                 loaded_dlls.append(relative_path[-1])
+             except Exception as e:
+                 print(f"Failed to load {dll_path}: {e}")
+
+     # Try load DLLs with default path settings.
+     has_failure = False
+     for relative_path in dll_paths:
+         dll_filename = relative_path[-1]
+         if dll_filename not in loaded_dlls:
+             try:
+                 _ = ctypes.CDLL(dll_filename)
+             except Exception as e:
+                 has_failure = True
+                 print(f"Failed to load {dll_filename}: {e}")
+
+     if has_failure:
+         print("Please follow https://onnxruntime.ai/docs/install/#cuda-and-cudnn to install CUDA and cuDNN.")
onnxruntime/backend/__init__.py
@@ -0,0 +1,6 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from .backend import is_compatible, prepare, run, supports_device  # noqa: F401
onnxruntime/backend/backend.py
@@ -0,0 +1,175 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ """
+ Implements ONNX's backend API.
+ """
+
+ import os
+ import unittest
+
+ import packaging.version
+ from onnx import ModelProto, helper, version  # noqa: F401
+ from onnx.backend.base import Backend
+ from onnx.checker import check_model
+
+ from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_device
+ from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep
+
+
+ class OnnxRuntimeBackend(Backend):
+     """
+     Implements
+     `ONNX's backend API <https://github.com/onnx/onnx/blob/main/docs/ImplementingAnOnnxBackend.md>`_
+     with *ONNX Runtime*.
+     The backend is mostly used when you need to switch between
+     multiple runtimes with the same API.
+     `Importing models from ONNX to Caffe2 <https://github.com/onnx/tutorials/blob/master/tutorials/OnnxCaffe2Import.ipynb>`_
+     shows how to use *caffe2* as a backend for a converted model.
+     Note: This is not the official Python API.
+     """
+
+     allowReleasedOpsetsOnly = bool(os.getenv("ALLOW_RELEASED_ONNX_OPSET_ONLY", "1") == "1")  # noqa: N815
+
+     @classmethod
+     def is_compatible(cls, model, device=None, **kwargs):
+         """
+         Return whether the model is compatible with the backend.
+
+         :param model: unused
+         :param device: None to use the default device or a string (ex: `'CPU'`)
+         :return: boolean
+         """
+         if device is None:
+             device = get_device()
+         return cls.supports_device(device)
+
+     @classmethod
+     def is_opset_supported(cls, model):
+         """
+         Return whether the opset for the model is supported by the backend.
+         By default, only released ONNX opsets are allowed by the backend.
+         To test new opsets, set the environment variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0.
+
+         :param model: Model whose opsets need to be verified.
+         :return: boolean and error message if opset is not supported.
+         """
+         if cls.allowReleasedOpsetsOnly:
+             for opset in model.opset_import:
+                 domain = opset.domain if opset.domain else "ai.onnx"
+                 try:
+                     key = (domain, opset.version)
+                     if key not in helper.OP_SET_ID_VERSION_MAP:
+                         error_message = (
+                             "Skipping this test as only released onnx opsets are supported. "
+                             "To run this test, set the env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
+                             f" Got Domain '{domain}' version '{opset.version}'."
+                         )
+                         return False, error_message
+                 except AttributeError:
+                     # For some CI pipelines, accessing helper.OP_SET_ID_VERSION_MAP
+                     # raises an AttributeError. TODO: investigate the pipelines to
+                     # fix this error. Fall back to a simple version check when this error is encountered.
+                     if (domain == "ai.onnx" and opset.version > 12) or (domain == "ai.onnx.ml" and opset.version > 2):
+                         error_message = (
+                             "Skipping this test as only released onnx opsets are supported. "
+                             "To run this test, set the env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
+                             f" Got Domain '{domain}' version '{opset.version}'."
+                         )
+                         return False, error_message
+         return True, ""
+
+     @classmethod
+     def supports_device(cls, device):
+         """
+         Check whether the backend is compiled with particular device support.
+         In particular it's used in the testing suite.
+         """
+         if device == "CUDA":
+             device = "GPU"
+         return "-" + device in get_device() or device + "-" in get_device() or device == get_device()
+
+     @classmethod
+     def prepare(cls, model, device=None, **kwargs):
+         """
+         Load the model and create a :class:`onnxruntime.InferenceSession`
+         ready to be used as a backend.
+
+         :param model: ModelProto (returned by `onnx.load`),
+             string for a filename or bytes for a serialized model
+         :param device: requested device for the computation,
+             None means the default one which depends on
+             the compilation settings
+         :param kwargs: see :class:`onnxruntime.SessionOptions`
+         :return: :class:`onnxruntime.InferenceSession`
+         """
+         if isinstance(model, OnnxRuntimeBackendRep):
+             return model
+         elif isinstance(model, InferenceSession):
+             return OnnxRuntimeBackendRep(model)
+         elif isinstance(model, (str, bytes)):
+             options = SessionOptions()
+             for k, v in kwargs.items():
+                 if hasattr(options, k):
+                     setattr(options, k, v)
+
+             excluded_providers = os.getenv("ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS", default="").split(",")
+             providers = [x for x in get_available_providers() if (x not in excluded_providers)]
+
+             inf = InferenceSession(model, sess_options=options, providers=providers)
+             # backend API is primarily used for ONNX test/validation. As such, we should disable session.run() fallback
+             # which may hide test failures.
+             inf.disable_fallback()
+             if device is not None and not cls.supports_device(device):
+                 raise RuntimeError(f"Incompatible device expected '{device}', got '{get_device()}'")
+             return cls.prepare(inf, device, **kwargs)
+         else:
+             # type: ModelProto
+             # check_model serializes the model anyways, so serialize the model once here
+             # and reuse it below in the cls.prepare call to avoid an additional serialization
+             # only works with onnx >= 1.10.0 hence the version check
+             onnx_version = packaging.version.parse(version.version) or packaging.version.Version("0")
+             onnx_supports_serialized_model_check = onnx_version.release >= (1, 10, 0)
+             bin_or_model = model.SerializeToString() if onnx_supports_serialized_model_check else model
+             check_model(bin_or_model)
+             opset_supported, error_message = cls.is_opset_supported(model)
+             if not opset_supported:
+                 raise unittest.SkipTest(error_message)
+             # Now bin might be serialized, if it's not we need to serialize it otherwise we'll have
+             # an infinite recursive call
+             bin = bin_or_model
+             if not isinstance(bin, (str, bytes)):
+                 bin = bin.SerializeToString()
+             return cls.prepare(bin, device, **kwargs)
+
+     @classmethod
+     def run_model(cls, model, inputs, device=None, **kwargs):
+         """
+         Compute the prediction.
+
+         :param model: :class:`onnxruntime.InferenceSession` returned
+             by function *prepare*
+         :param inputs: inputs
+         :param device: requested device for the computation,
+             None means the default one which depends on
+             the compilation settings
+         :param kwargs: see :class:`onnxruntime.RunOptions`
+         :return: predictions
+         """
+         rep = cls.prepare(model, device, **kwargs)
+         return rep.run(inputs, **kwargs)
+
+     @classmethod
+     def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
+         """
+         This method is not implemented as it is much more efficient
+         to run a whole model than every node independently.
+         """
+         raise NotImplementedError("It is much more efficient to run a whole model than every node independently.")
+
+
+ is_compatible = OnnxRuntimeBackend.is_compatible
+ prepare = OnnxRuntimeBackend.prepare
+ run = OnnxRuntimeBackend.run_model
+ supports_device = OnnxRuntimeBackend.supports_device
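A short sketch of how these module-level aliases are typically driven (the model path and input shape are hypothetical):

    import numpy as np

    import onnxruntime.backend as backend

    # prepare() builds an InferenceSession internally and wraps it in an
    # OnnxRuntimeBackendRep; the device must pass supports_device().
    rep = backend.prepare("model.onnx", device="CPU")

    # Inputs passed as a list are matched by position to the session inputs.
    x = np.random.rand(1, 3).astype(np.float32)
    outputs = backend.run(rep, [x])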
onnxruntime/backend/backend_rep.py
@@ -0,0 +1,52 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ """
+ Implements ONNX's backend API.
+ """
+
+ from onnx.backend.base import BackendRep
+
+ from onnxruntime import RunOptions
+
+
+ class OnnxRuntimeBackendRep(BackendRep):
+     """
+     Computes the prediction for a pipeline converted into
+     an :class:`onnxruntime.InferenceSession` node.
+     """
+
+     def __init__(self, session):
+         """
+         :param session: :class:`onnxruntime.InferenceSession`
+         """
+         self._session = session
+
+     def run(self, inputs, **kwargs):  # type: (Any, **Any) -> Tuple[Any, ...]
+         """
+         Computes the prediction.
+         See :meth:`onnxruntime.InferenceSession.run`.
+         """
+
+         options = RunOptions()
+         for k, v in kwargs.items():
+             if hasattr(options, k):
+                 setattr(options, k, v)
+
+         if isinstance(inputs, list):
+             inps = {}
+             for i, inp in enumerate(self._session.get_inputs()):
+                 inps[inp.name] = inputs[i]
+             outs = self._session.run(None, inps, options)
+             if isinstance(outs, list):
+                 return outs
+             else:
+                 output_names = [o.name for o in self._session.get_outputs()]
+                 return [outs[name] for name in output_names]
+         else:
+             inp = self._session.get_inputs()
+             if len(inp) != 1:
+                 raise RuntimeError(f"Model expects {len(inp)} inputs")
+             inps = {inp[0].name: inputs}
+             return self._session.run(None, inps, options)
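run() accepts either a list of arrays (matched by position) or, for single-input models, a bare array; keyword arguments whose names match RunOptions attributes are applied to the run. A small sketch (the model path is hypothetical):

    import numpy as np

    from onnxruntime import InferenceSession
    from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep

    rep = OnnxRuntimeBackendRep(InferenceSession("model.onnx", providers=["CPUExecutionProvider"]))

    # A bare array is bound to the model's single input; log_severity_level
    # matches a RunOptions attribute, so it is forwarded via setattr() above.
    x = np.zeros((1, 3), dtype=np.float32)
    outputs = rep.run(x, log_severity_level=3)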
onnxruntime/capi/DirectML.dll: Binary file
onnxruntime/capi/__init__.py
@@ -0,0 +1,4 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
onnxruntime/capi/_ld_preload.py
@@ -0,0 +1,7 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ # This file can be modified by setup.py when building a manylinux2010 wheel.
+ # When modified, it will preload some libraries needed for the Python C extension.
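In this Windows wheel the hook above stays empty. For illustration only, a manylinux build might have setup.py rewrite it along these lines (the library name is hypothetical, not what setup.py actually emits):

    # Hypothetical content setup.py could inject into _ld_preload.py on Linux:
    from ctypes import CDLL, RTLD_GLOBAL

    try:
        # RTLD_GLOBAL makes the library's symbols visible to the C extension
        # (onnxruntime_pybind11_state) that is imported afterwards.
        _ = CDLL("libcurand.so.10", mode=RTLD_GLOBAL)
    except OSError:
        # Leave resolution to the default loader search paths if not found.
        pass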