onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,71 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
class MapType(object):
    """Accessor for a ``MapType`` table stored in a FlatBuffers byte buffer.

    The only per-instance state is ``_tab``, which ``Init`` sets to a
    ``flatbuffers.table.Table``; every getter reads through it on each call.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``MapType`` positioned on the root table of ``buf``.

        The table position is the unsigned offset read at ``offset`` plus
        ``offset`` itself.
        """
        root_rel = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        view = MapType()
        view.Init(buf, root_rel + offset)
        return view

    @classmethod
    def GetRootAsMapType(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def MapTypeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check ``buf`` for the file identifier bytes ``ORTM``."""
        return flatbuffers.util.BufferHasIdentifier(
            buf,
            offset,
            b"\x4F\x52\x54\x4D",
            size_prefixed=size_prefixed,
        )

    # MapType
    def Init(self, buf, pos):
        """Bind this accessor to the table located at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # MapType
    def KeyType(self):
        """Return the key type as an int32, or 0 when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int32Flags, field_off + self._tab.Pos)

    # MapType
    def ValueType(self):
        """Return the value type as a ``TypeInfo``, or None when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return None
        table_pos = self._tab.Indirect(field_off + self._tab.Pos)
        # Imported at call time, exactly as in the generated code.
        from ort_flatbuffers_py.fbs.TypeInfo import TypeInfo
        value_type = TypeInfo()
        value_type.Init(self._tab.Bytes, table_pos)
        return value_type
48
+
49
def MapTypeStart(builder):
    """Begin a ``MapType`` table on ``builder`` with room for 2 fields."""
    builder.StartObject(2)
51
+
52
def Start(builder):
    """Short alias that delegates to ``MapTypeStart``."""
    MapTypeStart(builder)
54
+
55
def MapTypeAddKeyType(builder, keyType):
    """Store ``keyType`` as an int32 in slot 0 (trailing argument 0)."""
    builder.PrependInt32Slot(0, keyType, 0)
57
+
58
def AddKeyType(builder, keyType):
    """Short alias that delegates to ``MapTypeAddKeyType``."""
    MapTypeAddKeyType(builder, keyType)
60
+
61
def MapTypeAddValueType(builder, valueType):
    """Store the offset ``valueType`` in slot 1 as a relative uoffset."""
    value_offset = flatbuffers.number_types.UOffsetTFlags.py_type(valueType)
    builder.PrependUOffsetTRelativeSlot(1, value_offset, 0)
63
+
64
def AddValueType(builder, valueType):
    """Short alias that delegates to ``MapTypeAddValueType``."""
    MapTypeAddValueType(builder, valueType)
66
+
67
def MapTypeEnd(builder):
    """Finish the table under construction and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset
69
+
70
def End(builder):
    """Short alias that returns the result of ``MapTypeEnd``."""
    return MapTypeEnd(builder)
@@ -0,0 +1,223 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
class Model(object):
    """Accessor for a ``Model`` table stored in a FlatBuffers byte buffer.

    The only per-instance state is ``_tab``, which ``Init`` sets to a
    ``flatbuffers.table.Table``. Each getter looks up its field through
    ``_tab.Offset(...)`` and treats a result of 0 as "field absent".
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``Model`` positioned on the root table of ``buf``.

        The table position is the unsigned offset read at ``offset`` plus
        ``offset`` itself.
        """
        root_rel = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        view = Model()
        view.Init(buf, root_rel + offset)
        return view

    @classmethod
    def GetRootAsModel(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ModelBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check ``buf`` for the file identifier bytes ``ORTM``."""
        return flatbuffers.util.BufferHasIdentifier(
            buf,
            offset,
            b"\x4F\x52\x54\x4D",
            size_prefixed=size_prefixed,
        )

    # Model
    def Init(self, buf, pos):
        """Bind this accessor to the table located at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # Model
    def IrVersion(self):
        """Return the IR version as an int64, or 0 when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int64Flags, field_off + self._tab.Pos)

    # Model
    def OpsetImport(self, j):
        """Return element ``j`` of the opset-import vector as an ``OperatorSetId``.

        Returns None when the vector field is absent. ``j`` is not
        range-checked here.
        """
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return None
        # Elements are 4 bytes apart; each one is followed via Indirect.
        elem_pos = self._tab.Vector(field_off)
        elem_pos += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
        elem_pos = self._tab.Indirect(elem_pos)
        from ort_flatbuffers_py.fbs.OperatorSetId import OperatorSetId
        entry = OperatorSetId()
        entry.Init(self._tab.Bytes, elem_pos)
        return entry

    # Model
    def OpsetImportLength(self):
        """Return the opset-import vector length, or 0 when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return 0
        return self._tab.VectorLen(field_off)

    # Model
    def OpsetImportIsNone(self):
        """Return True when the opset-import vector field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return field_off == 0

    # Model
    def ProducerName(self):
        """Return the producer name via ``_tab.String``, or None when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)

    # Model
    def ProducerVersion(self):
        """Return the producer version via ``_tab.String``, or None when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)

    # Model
    def Domain(self):
        """Return the domain via ``_tab.String``, or None when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)

    # Model
    def ModelVersion(self):
        """Return the model version as an int64, or 0 when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        if field_off == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int64Flags, field_off + self._tab.Pos)

    # Model
    def DocString(self):
        """Return the doc string via ``_tab.String``, or None when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)

    # Model
    def Graph(self):
        """Return the nested ``Graph`` accessor, or None when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        if field_off == 0:
            return None
        table_pos = self._tab.Indirect(field_off + self._tab.Pos)
        # Imported at call time, exactly as in the generated code.
        from ort_flatbuffers_py.fbs.Graph import Graph
        graph = Graph()
        graph.Init(self._tab.Bytes, table_pos)
        return graph

    # Model
    def GraphDocString(self):
        """Return the graph doc string via ``_tab.String``, or None when absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)

    # Model
    def MetadataProps(self, j):
        """Return element ``j`` of the metadata vector as a ``StringStringEntry``.

        Returns None when the vector field is absent. ``j`` is not
        range-checked here.
        """
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        if field_off == 0:
            return None
        # Elements are 4 bytes apart; each one is followed via Indirect.
        elem_pos = self._tab.Vector(field_off)
        elem_pos += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
        elem_pos = self._tab.Indirect(elem_pos)
        from ort_flatbuffers_py.fbs.StringStringEntry import StringStringEntry
        entry = StringStringEntry()
        entry.Init(self._tab.Bytes, elem_pos)
        return entry

    # Model
    def MetadataPropsLength(self):
        """Return the metadata vector length, or 0 when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        if field_off == 0:
            return 0
        return self._tab.VectorLen(field_off)

    # Model
    def MetadataPropsIsNone(self):
        """Return True when the metadata vector field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        return field_off == 0
140
+
141
def ModelStart(builder):
    """Begin a ``Model`` table on ``builder`` with room for 10 fields."""
    builder.StartObject(10)
143
+
144
def Start(builder):
    """Short alias that delegates to ``ModelStart``."""
    ModelStart(builder)
146
+
147
def ModelAddIrVersion(builder, irVersion):
    """Store ``irVersion`` as an int64 in slot 0 (trailing argument 0)."""
    builder.PrependInt64Slot(0, irVersion, 0)
149
+
150
def AddIrVersion(builder, irVersion):
    """Short alias that delegates to ``ModelAddIrVersion``."""
    ModelAddIrVersion(builder, irVersion)
152
+
153
def ModelAddOpsetImport(builder, opsetImport):
    """Store the offset ``opsetImport`` in slot 1 as a relative uoffset."""
    vector_offset = flatbuffers.number_types.UOffsetTFlags.py_type(opsetImport)
    builder.PrependUOffsetTRelativeSlot(1, vector_offset, 0)
155
+
156
def AddOpsetImport(builder, opsetImport):
    """Short alias that delegates to ``ModelAddOpsetImport``."""
    ModelAddOpsetImport(builder, opsetImport)
158
+
159
def ModelStartOpsetImportVector(builder, numElems):
    """Start the opset-import vector via ``builder.StartVector(4, numElems, 4)``.

    Returns whatever ``StartVector`` returns.
    """
    started = builder.StartVector(4, numElems, 4)
    return started
161
+
162
def StartOpsetImportVector(builder, numElems: int) -> int:
    """Short alias that returns the result of ``ModelStartOpsetImportVector``."""
    return ModelStartOpsetImportVector(builder, numElems)
164
+
165
def ModelAddProducerName(builder, producerName):
    """Store the offset ``producerName`` in slot 2 as a relative uoffset."""
    name_offset = flatbuffers.number_types.UOffsetTFlags.py_type(producerName)
    builder.PrependUOffsetTRelativeSlot(2, name_offset, 0)
167
+
168
def AddProducerName(builder, producerName):
    """Short alias that delegates to ``ModelAddProducerName``."""
    ModelAddProducerName(builder, producerName)
170
+
171
def ModelAddProducerVersion(builder, producerVersion):
    """Store the offset ``producerVersion`` in slot 3 as a relative uoffset."""
    version_offset = flatbuffers.number_types.UOffsetTFlags.py_type(producerVersion)
    builder.PrependUOffsetTRelativeSlot(3, version_offset, 0)
173
+
174
def AddProducerVersion(builder, producerVersion):
    """Short alias that delegates to ``ModelAddProducerVersion``."""
    ModelAddProducerVersion(builder, producerVersion)
176
+
177
def ModelAddDomain(builder, domain):
    """Store the offset ``domain`` in slot 4 as a relative uoffset."""
    domain_offset = flatbuffers.number_types.UOffsetTFlags.py_type(domain)
    builder.PrependUOffsetTRelativeSlot(4, domain_offset, 0)
179
+
180
def AddDomain(builder, domain):
    """Short alias that delegates to ``ModelAddDomain``."""
    ModelAddDomain(builder, domain)
182
+
183
def ModelAddModelVersion(builder, modelVersion):
    """Store ``modelVersion`` as an int64 in slot 5 (trailing argument 0)."""
    builder.PrependInt64Slot(5, modelVersion, 0)
185
+
186
def AddModelVersion(builder, modelVersion):
    """Short alias that delegates to ``ModelAddModelVersion``."""
    ModelAddModelVersion(builder, modelVersion)
188
+
189
def ModelAddDocString(builder, docString):
    """Store the offset ``docString`` in slot 6 as a relative uoffset."""
    doc_offset = flatbuffers.number_types.UOffsetTFlags.py_type(docString)
    builder.PrependUOffsetTRelativeSlot(6, doc_offset, 0)
191
+
192
def AddDocString(builder, docString):
    """Short alias that delegates to ``ModelAddDocString``."""
    ModelAddDocString(builder, docString)
194
+
195
def ModelAddGraph(builder, graph):
    """Store the offset ``graph`` in slot 7 as a relative uoffset."""
    graph_offset = flatbuffers.number_types.UOffsetTFlags.py_type(graph)
    builder.PrependUOffsetTRelativeSlot(7, graph_offset, 0)
197
+
198
def AddGraph(builder, graph):
    """Short alias that delegates to ``ModelAddGraph``."""
    ModelAddGraph(builder, graph)
200
+
201
def ModelAddGraphDocString(builder, graphDocString):
    """Store the offset ``graphDocString`` in slot 8 as a relative uoffset."""
    doc_offset = flatbuffers.number_types.UOffsetTFlags.py_type(graphDocString)
    builder.PrependUOffsetTRelativeSlot(8, doc_offset, 0)
203
+
204
def AddGraphDocString(builder, graphDocString):
    """Short alias that delegates to ``ModelAddGraphDocString``."""
    ModelAddGraphDocString(builder, graphDocString)
206
+
207
def ModelAddMetadataProps(builder, metadataProps):
    """Store the offset ``metadataProps`` in slot 9 as a relative uoffset."""
    vector_offset = flatbuffers.number_types.UOffsetTFlags.py_type(metadataProps)
    builder.PrependUOffsetTRelativeSlot(9, vector_offset, 0)
209
+
210
def AddMetadataProps(builder, metadataProps):
    """Short alias that delegates to ``ModelAddMetadataProps``."""
    ModelAddMetadataProps(builder, metadataProps)
212
+
213
def ModelStartMetadataPropsVector(builder, numElems):
    """Start the metadata vector via ``builder.StartVector(4, numElems, 4)``.

    Returns whatever ``StartVector`` returns.
    """
    started = builder.StartVector(4, numElems, 4)
    return started
215
+
216
def StartMetadataPropsVector(builder, numElems: int) -> int:
    """Short alias that returns the result of ``ModelStartMetadataPropsVector``."""
    return ModelStartMetadataPropsVector(builder, numElems)
218
+
219
def ModelEnd(builder):
    """Finish the table under construction and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset
221
+
222
def End(builder):
    """Short alias that returns the result of ``ModelEnd``."""
    return ModelEnd(builder)
@@ -0,0 +1,141 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
class ModuleState(object):
    """Accessor for a ``ModuleState`` table stored in a FlatBuffers byte buffer.

    The only per-instance state is ``_tab``, which ``Init`` sets to a
    ``flatbuffers.table.Table``. Each getter looks up its field through
    ``_tab.Offset(...)`` and treats a result of 0 as "field absent".
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``ModuleState`` positioned on the root table of ``buf``.

        The table position is the unsigned offset read at ``offset`` plus
        ``offset`` itself.
        """
        root_rel = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        view = ModuleState()
        view.Init(buf, root_rel + offset)
        return view

    @classmethod
    def GetRootAsModuleState(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ModuleStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check ``buf`` for the file identifier bytes ``ODTC``."""
        return flatbuffers.util.BufferHasIdentifier(
            buf,
            offset,
            b"\x4F\x44\x54\x43",
            size_prefixed=size_prefixed,
        )

    # ModuleState
    def Init(self, buf, pos):
        """Bind this accessor to the table located at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # ModuleState
    def RequiresGradParams(self, j):
        """Return element ``j`` of the requires-grad vector as a ``Tensor``.

        Returns None when the vector field is absent. ``j`` is not
        range-checked here.
        """
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return None
        # Elements are 4 bytes apart; each one is followed via Indirect.
        elem_pos = self._tab.Vector(field_off)
        elem_pos += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
        elem_pos = self._tab.Indirect(elem_pos)
        from ort_flatbuffers_py.fbs.Tensor import Tensor
        tensor = Tensor()
        tensor.Init(self._tab.Bytes, elem_pos)
        return tensor

    # ModuleState
    def RequiresGradParamsLength(self):
        """Return the requires-grad vector length, or 0 when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return 0
        return self._tab.VectorLen(field_off)

    # ModuleState
    def RequiresGradParamsIsNone(self):
        """Return True when the requires-grad vector field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return field_off == 0

    # ModuleState
    def FrozenParams(self, j):
        """Return element ``j`` of the frozen-params vector as a ``Tensor``.

        Returns None when the vector field is absent. ``j`` is not
        range-checked here.
        """
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return None
        # Elements are 4 bytes apart; each one is followed via Indirect.
        elem_pos = self._tab.Vector(field_off)
        elem_pos += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
        elem_pos = self._tab.Indirect(elem_pos)
        from ort_flatbuffers_py.fbs.Tensor import Tensor
        tensor = Tensor()
        tensor.Init(self._tab.Bytes, elem_pos)
        return tensor

    # ModuleState
    def FrozenParamsLength(self):
        """Return the frozen-params vector length, or 0 when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return 0
        return self._tab.VectorLen(field_off)

    # ModuleState
    def FrozenParamsIsNone(self):
        """Return True when the frozen-params vector field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return field_off == 0

    # ModuleState
    def IsNominalState(self):
        """Return the stored flag as a bool, or False when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if field_off == 0:
            return False
        raw_flag = self._tab.Get(flatbuffers.number_types.BoolFlags, field_off + self._tab.Pos)
        return bool(raw_flag)

    # ModuleState
    def HasExternalData(self):
        """Return the stored flag as a bool, or False when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if field_off == 0:
            return False
        raw_flag = self._tab.Get(flatbuffers.number_types.BoolFlags, field_off + self._tab.Pos)
        return bool(raw_flag)
94
+
95
def ModuleStateStart(builder):
    """Begin a ``ModuleState`` table on ``builder`` with room for 4 fields."""
    builder.StartObject(4)
97
+
98
def Start(builder):
    """Short alias that delegates to ``ModuleStateStart``."""
    ModuleStateStart(builder)
100
+
101
def ModuleStateAddRequiresGradParams(builder, requiresGradParams):
    """Store the offset ``requiresGradParams`` in slot 0 as a relative uoffset."""
    vector_offset = flatbuffers.number_types.UOffsetTFlags.py_type(requiresGradParams)
    builder.PrependUOffsetTRelativeSlot(0, vector_offset, 0)
103
+
104
def AddRequiresGradParams(builder, requiresGradParams):
    """Short alias that delegates to ``ModuleStateAddRequiresGradParams``."""
    ModuleStateAddRequiresGradParams(builder, requiresGradParams)
106
+
107
def ModuleStateStartRequiresGradParamsVector(builder, numElems):
    """Start the requires-grad vector via ``builder.StartVector(4, numElems, 4)``.

    Returns whatever ``StartVector`` returns.
    """
    started = builder.StartVector(4, numElems, 4)
    return started
109
+
110
def StartRequiresGradParamsVector(builder, numElems: int) -> int:
    """Short alias that returns the result of ``ModuleStateStartRequiresGradParamsVector``."""
    return ModuleStateStartRequiresGradParamsVector(builder, numElems)
112
+
113
def ModuleStateAddFrozenParams(builder, frozenParams):
    """Store the offset ``frozenParams`` in slot 1 as a relative uoffset."""
    vector_offset = flatbuffers.number_types.UOffsetTFlags.py_type(frozenParams)
    builder.PrependUOffsetTRelativeSlot(1, vector_offset, 0)
115
+
116
def AddFrozenParams(builder, frozenParams):
    """Short alias that delegates to ``ModuleStateAddFrozenParams``."""
    ModuleStateAddFrozenParams(builder, frozenParams)
118
+
119
def ModuleStateStartFrozenParamsVector(builder, numElems):
    """Start the frozen-params vector via ``builder.StartVector(4, numElems, 4)``.

    Returns whatever ``StartVector`` returns.
    """
    started = builder.StartVector(4, numElems, 4)
    return started
121
+
122
def StartFrozenParamsVector(builder, numElems: int) -> int:
    """Short alias that returns the result of ``ModuleStateStartFrozenParamsVector``."""
    return ModuleStateStartFrozenParamsVector(builder, numElems)
124
+
125
def ModuleStateAddIsNominalState(builder, isNominalState):
    """Store ``isNominalState`` as a bool in slot 2 (trailing argument 0)."""
    builder.PrependBoolSlot(2, isNominalState, 0)
127
+
128
def AddIsNominalState(builder, isNominalState):
    """Short alias that delegates to ``ModuleStateAddIsNominalState``."""
    ModuleStateAddIsNominalState(builder, isNominalState)
130
+
131
def ModuleStateAddHasExternalData(builder, hasExternalData):
    """Store ``hasExternalData`` as a bool in slot 3 (trailing argument 0)."""
    builder.PrependBoolSlot(3, hasExternalData, 0)
133
+
134
def AddHasExternalData(builder, hasExternalData):
    """Short alias that delegates to ``ModuleStateAddHasExternalData``."""
    ModuleStateAddHasExternalData(builder, hasExternalData)
136
+
137
def ModuleStateEnd(builder):
    """Finish the table under construction and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset
139
+
140
def End(builder):
    """Short alias that returns the result of ``ModuleStateEnd``."""
    return ModuleStateEnd(builder)