onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,68 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
+ # deprecated: no longer using kernel def hashes
10
class DeprecatedNodeIndexAndKernelDefHash:
    """Accessor over a serialized ``DeprecatedNodeIndexAndKernelDefHash`` table.

    Holds one ``flatbuffers.table.Table`` in ``_tab`` and reads the
    ``NodeIndex`` and ``KernelDefHash`` scalar fields from it.
    """

    __slots__ = ["_tab"]

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Build an accessor for the root table of ``buf`` starting at ``offset``."""
        # The uoffset stored at ``offset`` is added to ``offset`` to locate the table.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = DeprecatedNodeIndexAndKernelDefHash()
        accessor.Init(buf, n + offset)
        return accessor

    @classmethod
    def GetRootAsDeprecatedNodeIndexAndKernelDefHash(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DeprecatedNodeIndexAndKernelDefHashBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``b"ORTM"`` (given below as hex escapes)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf,
            offset,
            b"\x4F\x52\x54\x4D",
            size_prefixed=size_prefixed,
        )

    # DeprecatedNodeIndexAndKernelDefHash
    def Init(self, buf, pos):
        """Bind this accessor to the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # DeprecatedNodeIndexAndKernelDefHash
    def NodeIndex(self):
        """Return the field looked up via ``Offset(4)`` read as uint32, or 0 when it is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Uint32Flags, field_off + self._tab.Pos)

    # DeprecatedNodeIndexAndKernelDefHash
    def KernelDefHash(self):
        """Return the field looked up via ``Offset(6)`` read as uint64, or 0 when it is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Uint64Flags, field_off + self._tab.Pos)
45
+
46
def DeprecatedNodeIndexAndKernelDefHashStart(builder):
    """Open a new table on ``builder`` with two field slots."""
    field_count = 2
    builder.StartObject(field_count)
48
+
49
def Start(builder):
    """Short alias that forwards to ``DeprecatedNodeIndexAndKernelDefHashStart``."""
    DeprecatedNodeIndexAndKernelDefHashStart(builder)
51
+
52
def DeprecatedNodeIndexAndKernelDefHashAddNodeIndex(builder, nodeIndex):
    """Write ``nodeIndex`` as a uint32 into slot 0 of the table being built."""
    slot = 0
    default_value = 0
    builder.PrependUint32Slot(slot, nodeIndex, default_value)
54
+
55
def AddNodeIndex(builder, nodeIndex):
    """Short alias that forwards to ``DeprecatedNodeIndexAndKernelDefHashAddNodeIndex``."""
    DeprecatedNodeIndexAndKernelDefHashAddNodeIndex(builder, nodeIndex)
57
+
58
def DeprecatedNodeIndexAndKernelDefHashAddKernelDefHash(builder, kernelDefHash):
    """Write ``kernelDefHash`` as a uint64 into slot 1 of the table being built."""
    slot = 1
    default_value = 0
    builder.PrependUint64Slot(slot, kernelDefHash, default_value)
60
+
61
def AddKernelDefHash(builder, kernelDefHash):
    """Short alias that forwards to ``DeprecatedNodeIndexAndKernelDefHashAddKernelDefHash``."""
    DeprecatedNodeIndexAndKernelDefHashAddKernelDefHash(builder, kernelDefHash)
63
+
64
def DeprecatedNodeIndexAndKernelDefHashEnd(builder):
    """Close the table under construction and return what ``builder.EndObject()`` returns."""
    finished = builder.EndObject()
    return finished
66
+
67
def End(builder):
    """Short alias that forwards to ``DeprecatedNodeIndexAndKernelDefHashEnd``."""
    return DeprecatedNodeIndexAndKernelDefHashEnd(builder)
@@ -0,0 +1,96 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
+ # deprecated: no longer using kernel def hashes
10
class DeprecatedSessionState:
    """Accessor over a serialized ``DeprecatedSessionState`` table.

    Holds one ``flatbuffers.table.Table`` in ``_tab`` and exposes a nested
    ``Kernels`` table plus a vector of ``SubGraphSessionStates`` tables.
    """

    __slots__ = ["_tab"]

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Build an accessor for the root table of ``buf`` starting at ``offset``."""
        # The uoffset stored at ``offset`` is added to ``offset`` to locate the table.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = DeprecatedSessionState()
        accessor.Init(buf, n + offset)
        return accessor

    @classmethod
    def GetRootAsDeprecatedSessionState(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DeprecatedSessionStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``b"ORTM"`` (given below as hex escapes)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf,
            offset,
            b"\x4F\x52\x54\x4D",
            size_prefixed=size_prefixed,
        )

    # DeprecatedSessionState
    def Init(self, buf, pos):
        """Bind this accessor to the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # DeprecatedSessionState
    def Kernels(self):
        """Return a ``DeprecatedKernelCreateInfos`` accessor for the nested table, or None if absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return None
        # Follow the stored indirection to the nested table's position.
        nested_pos = self._tab.Indirect(field_off + self._tab.Pos)
        from ort_flatbuffers_py.fbs.DeprecatedKernelCreateInfos import DeprecatedKernelCreateInfos
        kernels = DeprecatedKernelCreateInfos()
        kernels.Init(self._tab.Bytes, nested_pos)
        return kernels

    # DeprecatedSessionState
    def SubGraphSessionStates(self, j):
        """Return a ``DeprecatedSubGraphSessionState`` accessor for element ``j``, or None if the vector is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return None
        vector_start = self._tab.Vector(field_off)
        # Each element is a 4-byte offset, so element ``j`` sits ``j * 4`` past the start.
        element_pos = vector_start + flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
        table_pos = self._tab.Indirect(element_pos)
        from ort_flatbuffers_py.fbs.DeprecatedSubGraphSessionState import DeprecatedSubGraphSessionState
        state = DeprecatedSubGraphSessionState()
        state.Init(self._tab.Bytes, table_pos)
        return state

    # DeprecatedSessionState
    def SubGraphSessionStatesLength(self):
        """Return the vector length reported by ``_tab.VectorLen``, or 0 when the vector is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return 0
        return self._tab.VectorLen(field_off)

    # DeprecatedSessionState
    def SubGraphSessionStatesIsNone(self):
        """Return True when the ``SubGraphSessionStates`` field offset is 0, i.e. the vector is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return field_off == 0
67
+
68
def DeprecatedSessionStateStart(builder):
    """Open a new table on ``builder`` with two field slots."""
    field_count = 2
    builder.StartObject(field_count)
70
+
71
def Start(builder):
    """Short alias that forwards to ``DeprecatedSessionStateStart``."""
    DeprecatedSessionStateStart(builder)
73
+
74
def DeprecatedSessionStateAddKernels(builder, kernels):
    """Store the offset ``kernels`` in slot 0 of the table being built."""
    kernels_offset = flatbuffers.number_types.UOffsetTFlags.py_type(kernels)
    builder.PrependUOffsetTRelativeSlot(0, kernels_offset, 0)
76
+
77
def AddKernels(builder, kernels):
    """Short alias that forwards to ``DeprecatedSessionStateAddKernels``."""
    DeprecatedSessionStateAddKernels(builder, kernels)
79
+
80
def DeprecatedSessionStateAddSubGraphSessionStates(builder, subGraphSessionStates):
    """Store the offset ``subGraphSessionStates`` in slot 1 of the table being built."""
    vector_offset = flatbuffers.number_types.UOffsetTFlags.py_type(subGraphSessionStates)
    builder.PrependUOffsetTRelativeSlot(1, vector_offset, 0)
82
+
83
def AddSubGraphSessionStates(builder, subGraphSessionStates):
    """Short alias that forwards to ``DeprecatedSessionStateAddSubGraphSessionStates``."""
    DeprecatedSessionStateAddSubGraphSessionStates(builder, subGraphSessionStates)
85
+
86
def DeprecatedSessionStateStartSubGraphSessionStatesVector(builder, numElems):
    """Start a vector of ``numElems`` 4-byte elements with 4-byte alignment on ``builder``."""
    element_size = 4
    alignment = 4
    return builder.StartVector(element_size, numElems, alignment)
88
+
89
def StartSubGraphSessionStatesVector(builder, numElems: int) -> int:
    """Short alias that forwards to ``DeprecatedSessionStateStartSubGraphSessionStatesVector``."""
    return DeprecatedSessionStateStartSubGraphSessionStatesVector(builder, numElems)
91
+
92
def DeprecatedSessionStateEnd(builder):
    """Close the table under construction and return what ``builder.EndObject()`` returns."""
    finished = builder.EndObject()
    return finished
94
+
95
def End(builder):
    """Short alias that forwards to ``DeprecatedSessionStateEnd``."""
    return DeprecatedSessionStateEnd(builder)
@@ -0,0 +1,72 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
+ # deprecated: no longer using kernel def hashes
10
class DeprecatedSubGraphSessionState:
    """Accessor over a serialized ``DeprecatedSubGraphSessionState`` table.

    Holds one ``flatbuffers.table.Table`` in ``_tab`` and exposes a ``GraphId``
    string field and a nested ``SessionState`` table.
    """

    __slots__ = ["_tab"]

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Build an accessor for the root table of ``buf`` starting at ``offset``."""
        # The uoffset stored at ``offset`` is added to ``offset`` to locate the table.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = DeprecatedSubGraphSessionState()
        accessor.Init(buf, n + offset)
        return accessor

    @classmethod
    def GetRootAsDeprecatedSubGraphSessionState(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DeprecatedSubGraphSessionStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``b"ORTM"`` (given below as hex escapes)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf,
            offset,
            b"\x4F\x52\x54\x4D",
            size_prefixed=size_prefixed,
        )

    # DeprecatedSubGraphSessionState
    def Init(self, buf, pos):
        """Bind this accessor to the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # DeprecatedSubGraphSessionState
    def GraphId(self):
        """Return the ``GraphId`` field as read by ``_tab.String``, or None when it is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)

    # DeprecatedSubGraphSessionState
    def SessionState(self):
        """Return a ``DeprecatedSessionState`` accessor for the nested table, or None if absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return None
        # Follow the stored indirection to the nested table's position.
        nested_pos = self._tab.Indirect(field_off + self._tab.Pos)
        from ort_flatbuffers_py.fbs.DeprecatedSessionState import DeprecatedSessionState
        state = DeprecatedSessionState()
        state.Init(self._tab.Bytes, nested_pos)
        return state
49
+
50
def DeprecatedSubGraphSessionStateStart(builder):
    """Open a new table on ``builder`` with two field slots."""
    field_count = 2
    builder.StartObject(field_count)
52
+
53
def Start(builder):
    """Short alias that forwards to ``DeprecatedSubGraphSessionStateStart``."""
    DeprecatedSubGraphSessionStateStart(builder)
55
+
56
def DeprecatedSubGraphSessionStateAddGraphId(builder, graphId):
    """Store the offset ``graphId`` in slot 0 of the table being built."""
    graph_id_offset = flatbuffers.number_types.UOffsetTFlags.py_type(graphId)
    builder.PrependUOffsetTRelativeSlot(0, graph_id_offset, 0)
58
+
59
def AddGraphId(builder, graphId):
    """Short alias that forwards to ``DeprecatedSubGraphSessionStateAddGraphId``."""
    DeprecatedSubGraphSessionStateAddGraphId(builder, graphId)
61
+
62
def DeprecatedSubGraphSessionStateAddSessionState(builder, sessionState):
    """Store the offset ``sessionState`` in slot 1 of the table being built."""
    session_state_offset = flatbuffers.number_types.UOffsetTFlags.py_type(sessionState)
    builder.PrependUOffsetTRelativeSlot(1, session_state_offset, 0)
64
+
65
def AddSessionState(builder, sessionState):
    """Short alias that forwards to ``DeprecatedSubGraphSessionStateAddSessionState``."""
    DeprecatedSubGraphSessionStateAddSessionState(builder, sessionState)
67
+
68
def DeprecatedSubGraphSessionStateEnd(builder):
    """Close the table under construction and return what ``builder.EndObject()`` returns."""
    finished = builder.EndObject()
    return finished
70
+
71
def End(builder):
    """Short alias that forwards to ``DeprecatedSubGraphSessionStateEnd``."""
    return DeprecatedSubGraphSessionStateEnd(builder)
@@ -0,0 +1,71 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
class Dimension:
    """Accessor over a serialized ``Dimension`` table.

    Holds one ``flatbuffers.table.Table`` in ``_tab`` and exposes a nested
    ``Value`` table and a ``Denotation`` string field.
    """

    __slots__ = ["_tab"]

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Build an accessor for the root table of ``buf`` starting at ``offset``."""
        # The uoffset stored at ``offset`` is added to ``offset`` to locate the table.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = Dimension()
        accessor.Init(buf, n + offset)
        return accessor

    @classmethod
    def GetRootAsDimension(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DimensionBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``b"ORTM"`` (given below as hex escapes)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf,
            offset,
            b"\x4F\x52\x54\x4D",
            size_prefixed=size_prefixed,
        )

    # Dimension
    def Init(self, buf, pos):
        """Bind this accessor to the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # Dimension
    def Value(self):
        """Return a ``DimensionValue`` accessor for the nested table, or None if absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return None
        # Follow the stored indirection to the nested table's position.
        nested_pos = self._tab.Indirect(field_off + self._tab.Pos)
        from ort_flatbuffers_py.fbs.DimensionValue import DimensionValue
        dim_value = DimensionValue()
        dim_value.Init(self._tab.Bytes, nested_pos)
        return dim_value

    # Dimension
    def Denotation(self):
        """Return the ``Denotation`` field as read by ``_tab.String``, or None when it is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)
48
+
49
def DimensionStart(builder):
    """Open a new table on ``builder`` with two field slots."""
    field_count = 2
    builder.StartObject(field_count)
51
+
52
def Start(builder):
    """Short alias that forwards to ``DimensionStart``."""
    DimensionStart(builder)
54
+
55
def DimensionAddValue(builder, value):
    """Store the offset ``value`` in slot 0 of the table being built."""
    value_offset = flatbuffers.number_types.UOffsetTFlags.py_type(value)
    builder.PrependUOffsetTRelativeSlot(0, value_offset, 0)
57
+
58
def AddValue(builder, value):
    """Short alias that forwards to ``DimensionAddValue``."""
    DimensionAddValue(builder, value)
60
+
61
def DimensionAddDenotation(builder, denotation):
    """Store the offset ``denotation`` in slot 1 of the table being built."""
    denotation_offset = flatbuffers.number_types.UOffsetTFlags.py_type(denotation)
    builder.PrependUOffsetTRelativeSlot(1, denotation_offset, 0)
63
+
64
def AddDenotation(builder, denotation):
    """Short alias that forwards to ``DimensionAddDenotation``."""
    DimensionAddDenotation(builder, denotation)
66
+
67
def DimensionEnd(builder):
    """Close the table under construction and return what ``builder.EndObject()`` returns."""
    finished = builder.EndObject()
    return finished
69
+
70
def End(builder):
    """Short alias that forwards to ``DimensionEnd``."""
    return DimensionEnd(builder)
@@ -0,0 +1,80 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
class DimensionValue(object):
    """Read-only accessor over a serialized ``DimensionValue`` table.

    All reads go through ``self._tab``, the ``flatbuffers.table.Table``
    installed by ``Init``. Each field getter treats a lookup result of 0 as
    "field absent" and returns that field's default.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``DimensionValue`` bound to the root object in ``buf``.

        Reads a ``uoffset`` from ``buf`` at ``offset`` and initialises the
        accessor at that value plus ``offset``.
        """
        root_rel = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # The generated code always instantiates DimensionValue, not ``cls``.
        accessor = DimensionValue()
        accessor.Init(buf, root_rel + offset)
        return accessor

    @classmethod
    def GetRootAsDimensionValue(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DimensionValueBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check ``buf`` for the 4-byte identifier (ASCII "ORTM")."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed
        )

    # DimensionValue
    def Init(self, buf, pos):
        """Bind this accessor to ``buf`` at position ``pos``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # DimensionValue
    def DimType(self):
        """Return the Int8 ``dim_type`` field, or 0 when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int8Flags, field_off + self._tab.Pos)

    # DimensionValue
    def DimValue(self):
        """Return the Int64 ``dim_value`` field, or 0 when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int64Flags, field_off + self._tab.Pos)

    # DimensionValue
    def DimParam(self):
        """Return the ``dim_param`` string, or ``None`` when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)
51
+
52
def DimensionValueStart(builder):
    """Begin a ``DimensionValue`` object by calling ``builder.StartObject(3)``.

    Args:
        builder: Object exposing ``StartObject``.

    Returns:
        None.
    """
    builder.StartObject(3)


def Start(builder):
    """Short alias that forwards to ``DimensionValueStart``; returns None."""
    DimensionValueStart(builder)
57
+
58
def DimensionValueAddDimType(builder, dimType):
    """Write ``dimType`` into slot 0 via ``builder.PrependInt8Slot``.

    Args:
        builder: Object exposing ``PrependInt8Slot``.
        dimType: Int8 value to store; 0 is passed as the slot default.
    """
    builder.PrependInt8Slot(0, dimType, 0)


def AddDimType(builder, dimType):
    """Short alias that forwards to ``DimensionValueAddDimType``."""
    DimensionValueAddDimType(builder, dimType)
63
+
64
def DimensionValueAddDimValue(builder, dimValue):
    """Write ``dimValue`` into slot 1 via ``builder.PrependInt64Slot``.

    Args:
        builder: Object exposing ``PrependInt64Slot``.
        dimValue: Int64 value to store; 0 is passed as the slot default.
    """
    builder.PrependInt64Slot(1, dimValue, 0)


def AddDimValue(builder, dimValue):
    """Short alias that forwards to ``DimensionValueAddDimValue``."""
    DimensionValueAddDimValue(builder, dimValue)
69
+
70
def DimensionValueAddDimParam(builder, dimParam):
    """Write the ``dimParam`` offset into slot 2 of the object being built.

    ``dimParam`` is converted with ``UOffsetTFlags.py_type`` and passed to
    ``builder.PrependUOffsetTRelativeSlot`` with a default of 0.

    Args:
        builder: Object exposing ``PrependUOffsetTRelativeSlot``.
        dimParam: Offset of the previously written parameter string.
    """
    param_off = flatbuffers.number_types.UOffsetTFlags.py_type(dimParam)
    builder.PrependUOffsetTRelativeSlot(2, param_off, 0)


def AddDimParam(builder, dimParam):
    """Short alias that forwards to ``DimensionValueAddDimParam``."""
    DimensionValueAddDimParam(builder, dimParam)
75
+
76
def DimensionValueEnd(builder):
    """Finish the ``DimensionValue`` object and return ``builder.EndObject()``.

    Args:
        builder: Object exposing ``EndObject``.

    Returns:
        Whatever ``builder.EndObject()`` returns.
    """
    finished = builder.EndObject()
    return finished


def End(builder):
    """Short alias that forwards to ``DimensionValueEnd`` and returns its result."""
    return DimensionValueEnd(builder)
@@ -0,0 +1,8 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
class DimensionValueType(object):
    """Integer constants naming the kinds of dimension value."""

    # Plain class attributes, as emitted by the FlatBuffers compiler.
    UNKNOWN = 0
    VALUE = 1
    PARAM = 2
@@ -0,0 +1,32 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
class EdgeEnd(object):
    """Accessor for a fixed-layout ``EdgeEnd`` record.

    ``SizeOf`` reports 12; the three getters read values at byte offsets
    0, 4 and 8 from ``self._tab.Pos``.
    """

    __slots__ = ['_tab']

    @classmethod
    def SizeOf(cls):
        """Return the record size reported for ``EdgeEnd`` (12)."""
        return 12

    # EdgeEnd
    def Init(self, buf, pos):
        """Bind this accessor to ``buf`` at position ``pos``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # EdgeEnd
    def NodeIndex(self):
        """Return the Uint32 stored at offset 0 of the record."""
        field_pos = self._tab.Pos + flatbuffers.number_types.UOffsetTFlags.py_type(0)
        return self._tab.Get(flatbuffers.number_types.Uint32Flags, field_pos)

    # EdgeEnd
    def SrcArgIndex(self):
        """Return the Int32 stored at offset 4 of the record."""
        field_pos = self._tab.Pos + flatbuffers.number_types.UOffsetTFlags.py_type(4)
        return self._tab.Get(flatbuffers.number_types.Int32Flags, field_pos)

    # EdgeEnd
    def DstArgIndex(self):
        """Return the Int32 stored at offset 8 of the record."""
        field_pos = self._tab.Pos + flatbuffers.number_types.UOffsetTFlags.py_type(8)
        return self._tab.Get(flatbuffers.number_types.Int32Flags, field_pos)
26
+
27
def CreateEdgeEnd(builder, nodeIndex, srcArgIndex, dstArgIndex):
    """Write one ``EdgeEnd`` record to ``builder`` and return its offset.

    Calls ``builder.Prep(4, 12)`` and then prepends the fields in reverse
    order (``dstArgIndex``, ``srcArgIndex``, ``nodeIndex``).

    Args:
        builder: Object exposing ``Prep``, ``PrependInt32``,
            ``PrependUint32`` and ``Offset``.
        nodeIndex: Value written with ``PrependUint32``.
        srcArgIndex: Value written with ``PrependInt32``.
        dstArgIndex: Value written with ``PrependInt32``.

    Returns:
        Whatever ``builder.Offset()`` returns after the writes.
    """
    builder.Prep(4, 12)
    # Last field first: this call order must be kept exactly.
    for int32_field in (dstArgIndex, srcArgIndex):
        builder.PrependInt32(int32_field)
    builder.PrependUint32(nodeIndex)
    return builder.Offset()
@@ -0,0 +1,67 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
class FloatProperty(object):
    """Read-only accessor over a serialized ``FloatProperty`` table.

    All reads go through ``self._tab``, the ``flatbuffers.table.Table``
    installed by ``Init``. Each field getter treats a lookup result of 0 as
    "field absent" and returns that field's default.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``FloatProperty`` bound to the root object in ``buf``.

        Reads a ``uoffset`` from ``buf`` at ``offset`` and initialises the
        accessor at that value plus ``offset``.
        """
        root_rel = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # The generated code always instantiates FloatProperty, not ``cls``.
        accessor = FloatProperty()
        accessor.Init(buf, root_rel + offset)
        return accessor

    @classmethod
    def GetRootAsFloatProperty(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def FloatPropertyBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check ``buf`` for the 4-byte identifier (ASCII "ODTC")."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed
        )

    # FloatProperty
    def Init(self, buf, pos):
        """Bind this accessor to ``buf`` at position ``pos``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # FloatProperty
    def Name(self):
        """Return the ``name`` string, or ``None`` when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)

    # FloatProperty
    def Value(self):
        """Return the Float32 ``value`` field, or 0.0 when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return 0.0
        return self._tab.Get(flatbuffers.number_types.Float32Flags, field_off + self._tab.Pos)
44
+
45
def FloatPropertyStart(builder):
    """Begin a ``FloatProperty`` object by calling ``builder.StartObject(2)``.

    Args:
        builder: Object exposing ``StartObject``.

    Returns:
        None.
    """
    builder.StartObject(2)


def Start(builder):
    """Short alias that forwards to ``FloatPropertyStart``; returns None."""
    FloatPropertyStart(builder)
50
+
51
def FloatPropertyAddName(builder, name):
    """Write the ``name`` offset into slot 0 of the object being built.

    ``name`` is converted with ``UOffsetTFlags.py_type`` and passed to
    ``builder.PrependUOffsetTRelativeSlot`` with a default of 0.

    Args:
        builder: Object exposing ``PrependUOffsetTRelativeSlot``.
        name: Offset of the previously written name string.
    """
    name_off = flatbuffers.number_types.UOffsetTFlags.py_type(name)
    builder.PrependUOffsetTRelativeSlot(0, name_off, 0)


def AddName(builder, name):
    """Short alias that forwards to ``FloatPropertyAddName``."""
    FloatPropertyAddName(builder, name)
56
+
57
def FloatPropertyAddValue(builder, value):
    """Write ``value`` into slot 1 via ``builder.PrependFloat32Slot``.

    Args:
        builder: Object exposing ``PrependFloat32Slot``.
        value: Float32 value to store; 0.0 is passed as the slot default.
    """
    builder.PrependFloat32Slot(1, value, 0.0)


def AddValue(builder, value):
    """Short alias that forwards to ``FloatPropertyAddValue``."""
    FloatPropertyAddValue(builder, value)
62
+
63
def FloatPropertyEnd(builder):
    """Finish the ``FloatProperty`` object and return ``builder.EndObject()``.

    Args:
        builder: Object exposing ``EndObject``.

    Returns:
        Whatever ``builder.EndObject()`` returns.
    """
    finished = builder.EndObject()
    return finished


def End(builder):
    """Short alias that forwards to ``FloatPropertyEnd`` and returns its result."""
    return FloatPropertyEnd(builder)