onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,78 @@
1
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
ONNX Runtime is a performance-focused scoring engine for Open Neural Network Exchange (ONNX) models.
For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://aka.ms/onnxruntime/>`_
or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
"""

__version__ = "1.20.0"
__author__ = "Microsoft"

# Device/version validation (for example checking the CUDA version for an
# onnxruntime-training package) requires onnxruntime.capi._pybind_state to be
# imported first, but that import itself already raises when a required CUDA
# runtime is missing.  The exception is therefore captured here and re-raised
# only after onnxruntime_validation has run, so the user gets a meaningful
# message before the original failure surfaces.
try:
    from onnxruntime.capi._pybind_state import (  # noqa: F401
        ExecutionMode,
        ExecutionOrder,
        GraphOptimizationLevel,
        LoraAdapter,
        ModelMetadata,
        NodeArg,
        OrtAllocatorType,
        OrtArenaCfg,
        OrtMemoryInfo,
        OrtMemType,
        OrtSparseFormat,
        RunOptions,
        SessionIOBinding,
        SessionOptions,
        create_and_register_allocator,
        create_and_register_allocator_v2,
        disable_telemetry_events,
        enable_telemetry_events,
        get_all_providers,
        get_available_providers,
        get_build_info,
        get_device,
        get_version_string,
        has_collective_ops,
        set_default_logger_severity,
        set_default_logger_verbosity,
        set_seed,
    )

    import_capi_exception = None
except Exception as e:
    # Saved and re-raised below, after device/distro validation has run.
    import_capi_exception = e

from onnxruntime.capi import onnxruntime_validation

if import_capi_exception:
    raise import_capi_exception

from onnxruntime.capi.onnxruntime_inference_collection import (  # noqa: F401
    AdapterFormat,
    InferenceSession,
    IOBinding,
    OrtDevice,
    OrtValue,
    SparseTensor,
)

# TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
try:  # noqa: SIM105
    from . import experimental  # noqa: F401
except ImportError:
    pass

from onnxruntime.capi.onnxruntime_validation import cuda_version, package_name, version  # noqa: F401

# Prefer the version recorded by the validation module when available.
if version:
    __version__ = version

onnxruntime_validation.check_distro_info()
@@ -0,0 +1,6 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from .backend import is_compatible, prepare, run, supports_device # noqa: F401
@@ -0,0 +1,174 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ """
6
+ Implements ONNX's backend API.
7
+ """
8
+ import os
9
+ import unittest
10
+
11
+ import packaging.version
12
+ from onnx import ModelProto, helper, version # noqa: F401
13
+ from onnx.backend.base import Backend
14
+ from onnx.checker import check_model
15
+
16
+ from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_device
17
+ from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep
18
+
19
+
20
class OnnxRuntimeBackend(Backend):
    """
    Implements
    `ONNX's backend API <https://github.com/onnx/onnx/blob/main/docs/ImplementingAnOnnxBackend.md>`_
    with *ONNX Runtime*.
    The backend is mostly used when you need to switch between
    multiple runtimes with the same API.
    `Importing models from ONNX to Caffe2 <https://github.com/onnx/tutorials/blob/master/tutorials/OnnxCaffe2Import.ipynb>`_
    shows how to use *caffe2* as a backend for a converted model.
    Note: This is not the official Python API.
    """

    # When true (the default), only opsets that have been released with onnx
    # are accepted; set env var ALLOW_RELEASED_ONNX_OPSET_ONLY=0 to lift this.
    allowReleasedOpsetsOnly = bool(os.getenv("ALLOW_RELEASED_ONNX_OPSET_ONLY", "1") == "1")  # noqa: N815

    @classmethod
    def is_compatible(cls, model, device=None, **kwargs):
        """
        Return whether the model is compatible with the backend.

        :param model: unused
        :param device: None to use the default device or a string (ex: `'CPU'`)
        :return: boolean
        """
        if device is None:
            device = get_device()
        return cls.supports_device(device)

    @classmethod
    def is_opset_supported(cls, model):
        """
        Return whether the opset for the model is supported by the backend.
        By default only released onnx opsets are allowed by the backend.
        To test new opsets, env variable ALLOW_RELEASED_ONNX_OPSET_ONLY should be set to 0.

        :param model: Model whose opsets needed to be verified.
        :return: boolean and error message if opset is not supported.
        """
        if cls.allowReleasedOpsetsOnly:
            for opset in model.opset_import:
                domain = opset.domain if opset.domain else "ai.onnx"
                try:
                    key = (domain, opset.version)
                    if key not in helper.OP_SET_ID_VERSION_MAP:
                        error_message = (
                            "Skipping this test as only released onnx opsets are supported."
                            "To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
                            f" Got Domain '{domain}' version '{opset.version}'."
                        )
                        return False, error_message
                except AttributeError:
                    # for some CI pipelines accessing helper.OP_SET_ID_VERSION_MAP
                    # is generating attribute error. TODO investigate the pipelines to
                    # fix this error. Falling back to a simple version check when this error is encountered
                    # BUGFIX: the ML domain was previously spelled "ai.ommx.ml",
                    # so this fallback guard could never trigger for ML opsets.
                    if (domain == "ai.onnx" and opset.version > 12) or (domain == "ai.onnx.ml" and opset.version > 2):
                        error_message = (
                            "Skipping this test as only released onnx opsets are supported."
                            "To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
                            f" Got Domain '{domain}' version '{opset.version}'."
                        )
                        return False, error_message
        return True, ""

    @classmethod
    def supports_device(cls, device):
        """
        Check whether the backend is compiled with particular device support.
        In particular it's used in the testing suite.
        """
        if device == "CUDA":
            device = "GPU"
        return device in get_device()

    @classmethod
    def prepare(cls, model, device=None, **kwargs):
        """
        Load the model and creates a :class:`onnxruntime.InferenceSession`
        ready to be used as a backend.

        :param model: ModelProto (returned by `onnx.load`),
            string for a filename or bytes for a serialized model
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: see :class:`onnxruntime.SessionOptions`
        :return: :class:`onnxruntime.InferenceSession`
        """
        if isinstance(model, OnnxRuntimeBackendRep):
            return model
        elif isinstance(model, InferenceSession):
            return OnnxRuntimeBackendRep(model)
        elif isinstance(model, (str, bytes)):
            options = SessionOptions()
            # Forward only kwargs that map to real SessionOptions attributes.
            for k, v in kwargs.items():
                if hasattr(options, k):
                    setattr(options, k, v)

            excluded_providers = os.getenv("ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS", default="").split(",")
            providers = [x for x in get_available_providers() if (x not in excluded_providers)]

            inf = InferenceSession(model, sess_options=options, providers=providers)
            # backend API is primarily used for ONNX test/validation. As such, we should disable session.run() fallback
            # which may hide test failures.
            inf.disable_fallback()
            if device is not None and not cls.supports_device(device):
                raise RuntimeError(f"Incompatible device expected '{device}', got '{get_device()}'")
            return cls.prepare(inf, device, **kwargs)
        else:
            # type: ModelProto
            # check_model serializes the model anyways, so serialize the model once here
            # and reuse it below in the cls.prepare call to avoid an additional serialization
            # only works with onnx >= 1.10.0 hence the version check
            onnx_version = packaging.version.parse(version.version) or packaging.version.Version("0")
            onnx_supports_serialized_model_check = onnx_version.release >= (1, 10, 0)
            bin_or_model = model.SerializeToString() if onnx_supports_serialized_model_check else model
            check_model(bin_or_model)
            opset_supported, error_message = cls.is_opset_supported(model)
            if not opset_supported:
                raise unittest.SkipTest(error_message)
            # Now bin_or_model might already be serialized; if it's not, serialize it,
            # otherwise the recursive cls.prepare call below would loop forever.
            # (Renamed from `bin`, which shadowed the builtin.)
            serialized = bin_or_model
            if not isinstance(serialized, (str, bytes)):
                serialized = serialized.SerializeToString()
            return cls.prepare(serialized, device, **kwargs)

    @classmethod
    def run_model(cls, model, inputs, device=None, **kwargs):
        """
        Compute the prediction.

        :param model: :class:`onnxruntime.InferenceSession` returned
            by function *prepare*
        :param inputs: inputs
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: see :class:`onnxruntime.RunOptions`
        :return: predictions
        """
        rep = cls.prepare(model, device, **kwargs)
        return rep.run(inputs, **kwargs)

    @classmethod
    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
        """
        This method is not implemented as it is much more efficient
        to run a whole model than every node independently.
        """
        raise NotImplementedError("It is much more efficient to run a whole model than every node independently.")


# Module-level aliases forming the ONNX backend API surface.
is_compatible = OnnxRuntimeBackend.is_compatible
prepare = OnnxRuntimeBackend.prepare
run = OnnxRuntimeBackend.run_model
supports_device = OnnxRuntimeBackend.supports_device
@@ -0,0 +1,53 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ """
6
+ Implements ONNX's backend API.
7
+ """
8
+ from typing import Any, Tuple # noqa: F401
9
+
10
+ from onnx.backend.base import BackendRep
11
+
12
+ from onnxruntime import RunOptions
13
+
14
+
15
class OnnxRuntimeBackendRep(BackendRep):
    """
    Computes the prediction for a pipeline converted into
    an :class:`onnxruntime.InferenceSession` node.
    """

    def __init__(self, session):
        """
        :param session: :class:`onnxruntime.InferenceSession`
        """
        self._session = session

    def run(self, inputs, **kwargs):  # type: (Any, **Any) -> Tuple[Any, ...]
        """
        Computes the prediction.
        See :meth:`onnxruntime.InferenceSession.run`.
        """
        # Copy any keyword argument that matches a RunOptions attribute.
        options = RunOptions()
        for name, value in kwargs.items():
            if hasattr(options, name):
                setattr(options, name, value)

        if not isinstance(inputs, list):
            # Single tensor: the model must declare exactly one input.
            model_inputs = self._session.get_inputs()
            if len(model_inputs) != 1:
                raise RuntimeError(f"Model expect {len(model_inputs)} inputs")
            return self._session.run(None, {model_inputs[0].name: inputs}, options)

        # List of tensors: map them positionally onto the declared inputs.
        model_inputs = self._session.get_inputs()
        feeds = {model_inputs[i].name: inputs[i] for i in range(len(model_inputs))}
        results = self._session.run(None, feeds, options)
        if isinstance(results, list):
            return results
        # Defensive: if run() handed back a mapping, order it by output name.
        output_names = [out.name for out in self._session.get_outputs()]
        return [results[name] for name in output_names]
Binary file
@@ -0,0 +1,4 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
@@ -0,0 +1,7 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ # This file can be modified by setup.py when building a manylinux2010 wheel
7
+ # When modified, it will preload some libraries needed for the python C extension
@@ -0,0 +1,33 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ """
6
+ Ensure that dependencies are available and then load the extension module.
7
+ """
8
+ import os
9
+ import platform
10
+ import warnings
11
+
12
+ from . import _ld_preload # noqa: F401
13
+
14
# On Windows, a missing Visual C++ 2019 runtime is a common cause of failures
# when importing the native extension below, so warn about it up front.
if platform.system() == "Windows":
    from . import version_info

    # If on Windows, check if this import error is caused by the user not installing the 2019 VC Runtime
    # The VC Redist installer usually puts the VC Runtime dlls in the System32 folder, but it may also be found
    # in some other locations.
    # TODO, we may want to try to load the VC Runtime dlls instead of checking if the hardcoded file path
    # is valid, and raise ImportError if the load fails
    if version_info.vs2019 and platform.architecture()[0] == "64bit":
        # %SystemRoot% normally resolves to C:\Windows; fall back if unset.
        system_root = os.getenv("SystemRoot") or "C:\\Windows"
        # vcruntime140_1.dll ships with the VS 2019 (and later) redistributable.
        if not os.path.isfile(os.path.join(system_root, "System32", "vcruntime140_1.dll")):
            warnings.warn("Please install the 2019 Visual C++ runtime and then try again. "
                          "If you've installed the runtime in a non-standard location "
                          "(other than %SystemRoot%\\System32), "
                          "make sure it can be found by setting the correct path.")



# Re-export everything from the compiled pybind11 extension module.
from .onnxruntime_pybind11_state import *  # noqa
33
+
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ # This script helps converting .npz files to .onnx_adapter files
5
+
6
+ import argparse
7
+ import os
8
+ import sys
9
+
10
+ import numpy as np
11
+
12
+ import onnxruntime as ort
13
+
14
+
15
def get_args() -> argparse.Namespace:
    """Parse and return the command-line arguments.

    :return: namespace with ``npz_file_path``, ``output_file_path``,
        ``adapter_version`` and ``model_version`` attributes.
    """
    # Fix: the return annotation was `-> argparse` (the module object),
    # which is not a valid type; parse_args() returns an argparse.Namespace.
    parser = argparse.ArgumentParser()
    parser.add_argument("--npz_file_path", type=str, required=True)
    parser.add_argument("--output_file_path", type=str, required=True)
    parser.add_argument("--adapter_version", type=int, required=True)
    parser.add_argument("--model_version", type=int, required=True)
    return parser.parse_args()
22
+
23
+
24
def export_lora_parameters(
    npz_file_path: os.PathLike, adapter_version: int, model_version: int, output_file_path: os.PathLike
):
    """The function converts lora parameters in npz to onnx_adapter format"""
    adapter = ort.AdapterFormat()
    adapter.set_adapter_version(adapter_version)
    adapter.set_model_version(model_version)
    # Wrap every numpy array from the archive in an OrtValue, keyed by name.
    with np.load(npz_file_path) as data:
        parameters = {name: ort.OrtValue.ortvalue_from_numpy(array) for name, array in data.items()}
    adapter.set_parameters(parameters)
    adapter.export_adapter(output_file_path)
39
+
40
+
41
def main() -> int:
    """Entry point: convert the given .npz file into an .onnx_adapter file."""
    args = get_args()
    export_lora_parameters(
        args.npz_file_path,
        args.adapter_version,
        args.model_version,
        args.output_file_path,
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())
Binary file
@@ -0,0 +1,47 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import ctypes
6
+ import sys
7
+ import warnings
8
+
9
+
10
def find_cudart_versions(build_env=False, build_cuda_version=None):
    """Return the cudart versions loadable in this environment (Linux only).

    ctypes.CDLL and ctypes.util.find_library pick up the most recently
    installed library, which may not be the one onnxruntime itself loads:
    e.g. with CUDA 11.1 present and conda cudatoolkit 10.2.89 installed
    afterwards, ctypes finds cudart 10.2 while onnxruntime built against
    CUDA 11.1 loads cudart 11.1.  So every candidate version is probed and
    the caller should warn only if the expected CUDA version is missing.
    In the onnxruntime build environment a single CUDA version is expected.

    :param build_env: unused; kept for interface compatibility.
    :param build_cuda_version: optional versioned ``libcudart.so`` suffix
        to probe in addition to the unversioned name.
    :return: list of found cudart versions, or ``None`` off Linux.
    """
    if not sys.platform.startswith("linux"):
        warnings.warn("find_cudart_versions only works on Linux")
        return None

    candidate_suffixes = {None, build_cuda_version}

    def probe_version(version_suffix=None):
        # Try "libcudart.so" or "libcudart.so.<suffix>" and ask the library
        # itself for its runtime version; None means "not found/failed".
        lib_name = "libcudart.so"
        if version_suffix:
            lib_name = lib_name + "." + version_suffix

        try:
            cudart = ctypes.CDLL(lib_name)
            cudart.cudaRuntimeGetVersion.restype = int
            cudart.cudaRuntimeGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
            found = ctypes.c_int()
            if cudart.cudaRuntimeGetVersion(ctypes.byref(found)) != 0:
                return None
        except Exception:
            return None

        return found.value

    # A set collapses duplicate results; drop the None placeholders below.
    found_versions = {probe_version(suffix) for suffix in candidate_suffixes}
    return [ver for ver in found_versions if ver]