onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,150 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ """
6
+ Check OS requirements for ONNX Runtime Python Bindings.
7
+ """
8
+ import linecache
9
+ import platform
10
+ import warnings
11
+
12
+
13
def check_distro_info():
    """Warn when the host OS is not one ONNX Runtime declares support for.

    The check is purely advisory: it inspects the running platform, emits
    ``warnings.warn`` messages for unsupported Windows or macOS versions and
    for unknown platforms, and always returns ``None``. A failure while
    probing the platform (an unparsable version string, a missing ``oslevel``
    command) is tolerated instead of being raised to the caller.
    """
    system = platform.system().lower()

    # Collected for the benefit of future checks; nothing below consumes the
    # Linux or AIX values yet.
    distro = ""
    distro_ver = ""

    os_release_file = "/etc/os-release"
    lsb_release_file = "/etc/lsb-release"

    if system == "windows":
        distro = system
        distro_ver = platform.release().lower()

        if distro_ver not in ("10", "11"):
            warnings.warn(
                f"Unsupported Windows version ({distro_ver}). ONNX Runtime supports Windows 10 and above, only."
            )
    elif system == "linux":
        # Although the 'platform' python module for getting distro information works well on
        # standard OS images running on real hardware, it is not accurate when running on
        # Azure VMs, Git Bash, Cygwin, etc. The returned values for release and version are
        # unpredictable for virtualized or emulated environments.
        # The /etc/os-release and /etc/lsb-release files, on the other hand, are expected to
        # exist and have standard values in all OSes supported by onnxruntime. The former is
        # the current standard file to check OS info and the latter is its predecessor.

        # Newer systems have /etc/os-release with relevant distro info.
        # NOTE(review): the fixed line numbers and slice offsets presumably target
        # `ID=<name>` on line 3 and `VERSION_ID="<ver>"` on line 6 - confirm.
        distro = linecache.getline(os_release_file, 3)[3:-1]
        distro_ver = linecache.getline(os_release_file, 6)[12:-2]

        # Older systems may have /etc/lsb-release instead
        if not distro:
            distro = linecache.getline(lsb_release_file, 1)[11:-1]
            distro_ver = linecache.getline(lsb_release_file, 2)[16:-1]

        # Instead of trying to parse distro specific files,
        # warn the user ONNX Runtime may not work out of the box
        distro = distro.lower()
        distro_ver = distro_ver.lower()
    elif system == "darwin":
        distro = system
        # NOTE(review): platform.release() is compared against the macOS
        # marketing version 11 here - confirm it is the intended source.
        distro_ver = platform.release().lower()

        try:
            major = int(distro_ver.split(".")[0])
        except ValueError:
            # Unexpected release format: skip the check rather than crash.
            major = None

        if major is not None and major < 11:
            warnings.warn(
                f"Unsupported macOS version ({distro_ver}). ONNX Runtime supports macOS 11.0 or later."
            )
    elif system == "aix":
        import subprocess

        try:
            returned_output = subprocess.check_output("oslevel")
        except (OSError, subprocess.CalledProcessError):
            # `oslevel` missing or failing must not break an advisory check.
            returned_output = b""

        distro = system
        distro_ver = returned_output.decode("utf-8", errors="replace")[:3]
    else:
        warnings.warn(
            f"Unsupported platform ({system}). ONNX Runtime supports Linux, macOS, AIX and Windows platforms, only."
        )
68
+
69
+
70
def validate_build_package_info():
    """Detect the ORTModule training package and sanity-check its CUDA build info.

    Returns:
        A tuple ``(has_ortmodule, package_name, version, cuda_version)``.
        The three strings stay empty when ORTModule is not importable or the
        build info module cannot be read.

    Raises:
        Exception: any non-``ImportError`` exception raised while importing
            ORTModule, unless it is an ``ORTModuleInitException``. It is
            re-raised only after the remaining validation has run.
    """
    import_ortmodule_exception = None

    has_ortmodule = False
    try:
        from onnxruntime.training.ortmodule import ORTModule  # noqa: F401

        has_ortmodule = True
    except ImportError:
        # ORTModule not present
        has_ortmodule = False
    except Exception as e:
        # This may happen if CUDA is not installed. For any exception other
        # than ORTModule being absent, continue with device version
        # validation and raise the exception afterwards.
        try:
            from onnxruntime.training.ortmodule._fallback import ORTModuleInitException

            if isinstance(e, ORTModuleInitException):
                # ORTModule is present but not ready to run yet
                has_ortmodule = True
        except Exception:
            # ORTModule not present
            has_ortmodule = False

        if not has_ortmodule:
            import_ortmodule_exception = e

    package_name = ""
    version = ""
    cuda_version = ""

    if has_ortmodule:
        try:
            # collect onnxruntime package name, version, and cuda version
            from .build_and_package_info import __version__ as version
            from .build_and_package_info import package_name

            try:  # noqa: SIM105
                from .build_and_package_info import cuda_version
            except Exception:
                # Builds without cuda_version keep the empty default.
                pass

            if cuda_version:
                # Collect cuda library build info. The library info may not be available
                # when the build environment has none or multiple libraries installed.
                try:
                    from .build_and_package_info import cudart_version
                except Exception:
                    warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
                    cudart_version = None

                # Collect cuda library info from the current environment.
                from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions

                local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
                if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions:
                    # Report the build info first so the mismatch can be diagnosed.
                    warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
                    warnings.warn(f"onnxruntime training package info: __version__: {version}")
                    warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
                    warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")
                    warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
                    warnings.warn(f"WARNING: found cudart versions: {local_cudart_versions}")
            # TODO: rocm

        except Exception as e:
            # Surface the cause through the warning machinery, not stdout.
            warnings.warn(f"WARNING: failed to collect onnxruntime version and build info: {e}")

    if import_ortmodule_exception:
        raise import_ortmodule_exception

    return has_ortmodule, package_name, version, cuda_version


has_ortmodule, package_name, version, cuda_version = validate_build_package_info()
@@ -0,0 +1,2 @@
1
+ use_cuda = False
2
+ vs2019 = False
@@ -0,0 +1,17 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+ """
4
+ Short examples used in the documentation.
5
+ """
6
+ import os
7
+
8
+
9
def get_example(name):
    """
    Return the absolute path of the example file *name* located in the
    directory of this module.

    Raises FileNotFoundError when no such path exists.
    """
    package_dir = os.path.dirname(__file__)
    candidate = os.path.join(os.path.abspath(package_dir), name)
    if os.path.exists(candidate):
        return candidate
    raise FileNotFoundError(f"Unable to find example '{name}'")
Binary file
Binary file
@@ -0,0 +1,13 @@
1
+  backend-test:Q
2
+ 
3
+ xy"Sigmoid test_sigmoidZ
4
+ x
5
+ 
6
+ 
7
+ 
8
+ b
9
+ y
10
+ 
11
+ 
12
+ 
13
+ B
@@ -0,0 +1,78 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: CalTableFlatBuffers
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+
8
+ np = import_numpy()
9
+
10
+
11
class KeyValue:
    """Read accessor over a FlatBuffers table exposing ``Key`` and ``Value``.

    NOTE(review): this file is marked as generated by the FlatBuffers
    compiler; regenerate rather than hand-edit if the schema changes.
    """

    __slots__ = ["_tab"]

    @classmethod
    def GetRootAs(cls, buf, offset=0):  # noqa: N802
        """Build a KeyValue positioned at the root offset read from *buf*."""
        root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        instance = KeyValue()
        instance.Init(buf, root + offset)
        return instance

    @classmethod
    def GetRootAsKeyValue(cls, buf, offset=0):  # noqa: N802
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    # KeyValue
    def Init(self, buf, pos):  # noqa: N802
        """Bind this accessor to the table located at *pos* inside *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # KeyValue
    def Key(self):  # noqa: N802
        """Return the string at vtable slot 4, or None when the slot is unset."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return None
        return self._tab.String(field + self._tab.Pos)

    # KeyValue
    def Value(self):  # noqa: N802
        """Return the string at vtable slot 6, or None when the slot is unset."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field == 0:
            return None
        return self._tab.String(field + self._tab.Pos)
43
+
44
+
45
def Start(builder):  # noqa: N802
    """Open a two-field object on *builder*."""
    builder.StartObject(2)


def KeyValueStart(builder):  # noqa: N802
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def AddKey(builder, key):  # noqa: N802
    """Store the offset *key* in slot 0 of the object under construction."""
    key_offset = flatbuffers.number_types.UOffsetTFlags.py_type(key)
    builder.PrependUOffsetTRelativeSlot(0, key_offset, 0)


def KeyValueAddKey(builder, key):  # noqa: N802
    """This method is deprecated. Please switch to AddKey."""
    return AddKey(builder, key)


def AddValue(builder, value):  # noqa: N802
    """Store the offset *value* in slot 1 of the object under construction."""
    value_offset = flatbuffers.number_types.UOffsetTFlags.py_type(value)
    builder.PrependUOffsetTRelativeSlot(1, value_offset, 0)


def KeyValueAddValue(builder, value):  # noqa: N802
    """This method is deprecated. Please switch to AddValue."""
    return AddValue(builder, value)


def End(builder):  # noqa: N802
    """Close the object and return whatever ``builder.EndObject()`` returns."""
    finished = builder.EndObject()
    return finished


def KeyValueEnd(builder):  # noqa: N802
    """This method is deprecated. Please switch to End."""
    return End(builder)
@@ -0,0 +1,90 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: CalTableFlatBuffers
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+
8
+ np = import_numpy()
9
+
10
+
11
class TrtTable:
    """Read accessor over a FlatBuffers table holding a vector of KeyValue entries.

    NOTE(review): this file is marked as generated by the FlatBuffers
    compiler; regenerate rather than hand-edit if the schema changes.
    """

    __slots__ = ["_tab"]

    @classmethod
    def GetRootAs(cls, buf, offset=0):  # noqa: N802
        """Build a TrtTable positioned at the root offset read from *buf*."""
        root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        instance = TrtTable()
        instance.Init(buf, root + offset)
        return instance

    @classmethod
    def GetRootAsTrtTable(cls, buf, offset=0):  # noqa: N802
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    # TrtTable
    def Init(self, buf, pos):  # noqa: N802
        """Bind this accessor to the table located at *pos* inside *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # TrtTable
    def Dict(self, j):  # noqa: N802
        """Return entry *j* of the vector at vtable slot 4 as a KeyValue, or None if unset."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return None

        # Entries are addressed with a stride of 4 and then followed indirectly.
        element = self._tab.Vector(field)
        element += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
        element = self._tab.Indirect(element)
        from onnxruntime.quantization.CalTableFlatBuffers.KeyValue import KeyValue

        entry = KeyValue()
        entry.Init(self._tab.Bytes, element)
        return entry

    # TrtTable
    def DictLength(self):  # noqa: N802
        """Return the length of the vector at vtable slot 4, or 0 when it is unset."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return 0
        return self._tab.VectorLen(field)

    # TrtTable
    def DictIsNone(self):  # noqa: N802
        """Return True when vtable slot 4 is unset."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return field == 0
55
+
56
+
57
def Start(builder):  # noqa: N802
    """Open a one-field object on *builder*."""
    builder.StartObject(1)


def TrtTableStart(builder):  # noqa: N802
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def AddDict(builder, dict):  # noqa: N802
    """Store the offset *dict* in slot 0 of the object under construction.

    The parameter name shadows the builtin; it is kept so existing keyword
    callers continue to work.
    """
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(dict), 0)


def TrtTableAddDict(builder, dict):  # noqa: N802
    """This method is deprecated. Please switch to AddDict."""
    return AddDict(builder, dict)


def StartDictVector(builder, numElems):  # noqa: N802
    """Start a vector of *numElems* entries and return ``builder.StartVector``'s result.

    The element size and alignment arguments passed to the builder are both 4.
    """
    return builder.StartVector(4, numElems, 4)


def TrtTableStartDictVector(builder, numElems):  # noqa: N802
    """This method is deprecated. Please switch to StartDictVector."""
    return StartDictVector(builder, numElems)


def End(builder):  # noqa: N802
    """Close the object and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def TrtTableEnd(builder):  # noqa: N802
    """This method is deprecated. Please switch to End."""
    return End(builder)
@@ -0,0 +1,16 @@
1
+ from .calibrate import ( # noqa: F401
2
+ CalibraterBase,
3
+ CalibrationDataReader,
4
+ CalibrationMethod,
5
+ MinMaxCalibrater,
6
+ create_calibrator,
7
+ )
8
+ from .qdq_quantizer import QDQQuantizer # noqa: F401
9
+ from .quant_utils import QuantFormat, QuantType, write_calibration_table # noqa: F401
10
+ from .quantize import DynamicQuantConfig # noqa: F401
11
+ from .quantize import QuantizationMode # noqa: F401
12
+ from .quantize import StaticQuantConfig # noqa: F401
13
+ from .quantize import quantize # noqa: F401
14
+ from .quantize import quantize_dynamic # noqa: F401
15
+ from .quantize import quantize_static # noqa: F401
16
+ from .shape_inference import quant_pre_process # noqa: F401