onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,131 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ import inspect
5
+ from collections import abc
6
+
7
+ import torch
8
+
9
+
10
def _parse_inputs_for_onnx_export(all_input_parameters, inputs, kwargs):
    """Derive the flattened, ordered ONNX input names from a forward() signature
    and the concrete arguments it was called with.

    Logic extracted from
    https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L433

    Sequences are expanded element-wise into "<name>_<index>" inputs and mappings
    into "<name>_<key>" inputs; None values are dropped entirely.

    :param all_input_parameters: inspect.Parameter objects of the module's forward().
    :param inputs: Positional argument values.
    :param kwargs: Keyword argument values.
    :return: Ordered list of flattened input names.
    """
    collected_names = []

    def _register(name, value):
        """Recursively flatten one argument; returns the number of non-None leaf inputs added."""
        if value is None:
            # None inputs are dropped and contribute nothing.
            return 0

        added = 0
        if isinstance(value, abc.Sequence):
            # Expand list-like values so each element is an input by itself,
            # named with the element index appended to the original name.
            # The sequence itself is not a valid input once expanded.
            for idx, element in enumerate(value):
                added += _register(f"{name}_{idx}", element)
            return added
        if isinstance(value, abc.Mapping):
            # Expand dict-like values so each entry is an input by itself,
            # named with the key appended to the original name.
            for key, element in value.items():
                added += _register(f"{name}_{key}", element)
            return added

        # Leaf value: record the name irrespective of whether it ends up as
        # part of the onnx graph or not.
        collected_names.append(name)
        return 1

    var_positional_idx = 0
    expanded_positional_count = 0

    for param_idx, parameter in enumerate(all_input_parameters):
        if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
            # *args carries all remaining positional values from the original
            # forward method; each maps to "<args>_<n>".
            for args_i in range(param_idx, len(inputs)):
                generated_name = f"{parameter.name}_{var_positional_idx}"
                var_positional_idx += 1
                expanded_positional_count += _register(generated_name, inputs[args_i])
        elif parameter.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            # Ordinary (non-*args/non-**kwargs) parameter: prefer the positional
            # value, otherwise fall back to a matching keyword argument.
            value = None
            came_from_positional = True
            # Shift the positional index past any values consumed by *args expansion.
            effective_idx = param_idx + var_positional_idx
            if effective_idx < len(inputs) and inputs[effective_idx] is not None:
                value = inputs[effective_idx]
            elif parameter.name in kwargs and kwargs[parameter.name] is not None:
                value = kwargs[parameter.name]
                came_from_positional = False
            added_here = _register(parameter.name, value)
            if came_from_positional:
                expanded_positional_count += added_here
        elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
            # **kwargs is always the last argument of forward(); add any keyword
            # that was not already consumed above.
            for kw_name, kw_value in kwargs.items():
                if kw_name not in collected_names:
                    _register(kw_name, kw_value)

    return collected_names
86
+
87
+
88
def _flatten_module_input(names, args, kwargs):
    """Flatten args and kwargs into a single tuple suitable for torch.onnx.export.

    Logic extracted from
    https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L110

    Scalar Python primitives (int/bool/float) are promoted to tensors; all other
    values pass through unchanged.
    """

    def _promote(value):
        # Only exact int/bool/float instances are converted; tensors, strings,
        # containers, etc. are passed through as-is.
        return torch.tensor(value) if type(value) in {int, bool, float} else value

    flattened = [_promote(arg) for arg in args]
    flattened.extend(_promote(kwargs[name]) for name in names if name in kwargs)

    # If kwargs is empty, append an empty dictionary at the end of the sample
    # inputs to make the exporter happy — otherwise it confuses keyword
    # arguments with dictionary inputs.
    if not kwargs:
        flattened.append({})

    return tuple(flattened)
109
+
110
+
111
def infer_input_info(module: torch.nn.Module, *inputs, **kwargs):
    """
    Infer the input names and order from the arguments used to execute a PyTorch
    module, for usage exporting the model via torch.onnx.export.
    Assumes model is on CPU. Use `module.to(torch.device('cpu'))` if it isn't.

    Example usage:
        input_names, inputs_as_tuple = infer_input_info(module, ...)
        torch.onnx.export(module, inputs_as_tuple, 'model.onnx', input_names=input_names, output_names=[...], ...)

    :param module: Module
    :param inputs: Positional inputs
    :param kwargs: Keyword argument inputs
    :return: Tuple of ordered input names and input values. These can be used
        directly with torch.onnx.export as the `input_names` and `inputs` arguments.
    """
    forward_parameters = inspect.signature(module.forward).parameters.values()
    names = _parse_inputs_for_onnx_export(forward_parameters, inputs, kwargs)
    flattened = _flatten_module_input(names, inputs, kwargs)
    return names, flattened
File without changes
@@ -0,0 +1,37 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+
5
+ import argparse
6
+ import os
7
+ import pathlib
8
+
9
+ import onnx
10
+
11
+
12
def optimize_qdq_model():
    """Command-line entry point: load a QDQ-format ONNX model and re-save it.

    Historically this script applied QDQ-specific fixups (such as handling DQ
    nodes with multiple consumers); that work moved into an ORT graph
    transformer, so no optimizations run currently. The entry point is kept as
    a hook for future ones.
    """
    arg_parser = argparse.ArgumentParser(
        os.path.basename(__file__),
        description="Update a QDQ format ONNX model to ensure optimal performance when executed using ONNX Runtime.",
    )
    arg_parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
    arg_parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")
    parsed = arg_parser.parse_args()

    # strict=True makes resolution fail early if the input path does not exist.
    loaded_model = onnx.load(str(parsed.input_model.resolve(strict=True)))

    # run QDQ model optimizations here
    #
    # Originally, the fixing up of DQ nodes with multiple consumers was
    # implemented as one such optimization. That was moved to an ORT graph
    # transformer.
    print("As of ORT 1.15, the fixing up of DQ nodes with multiple consumers is done by an ORT graph transformer.")

    # There are no optimizations being run currently but we expect that there may be in the future.

    onnx.save(loaded_model, str(parsed.output_model.resolve()))


if __name__ == "__main__":
    optimize_qdq_model()
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

# Check if the flatbuffers module is available. If not we cannot handle type reduction information in the config.
try:
    import flatbuffers  # noqa: F401

    have_flatbuffers = True
    # The type-reduction helpers depend on the ORT flatbuffers schema, so they are
    # only imported when the flatbuffers module itself is present.
    from .ort_format_model import GloballyAllowedTypesOpTypeImplFilter, OperatorTypeUsageManager
except ImportError:
    # Without flatbuffers, parse_config silently ignores any type reduction info.
    have_flatbuffers = False
def parse_config(config_file: str, enable_type_reduction: bool = False):
    """
    Parse the configuration file and return the required operators dictionary and an
    OpTypeImplFilterInterface instance.

    Configuration file lines can do the following:
    1. specify required operators
    2. specify globally allowed types for all operators
    3. specify what it means for no required operators to be specified

    1. Specifying required operators

    The basic format for specifying required operators is `domain;opset1,opset2;op1,op2...`
    e.g. `ai.onnx;11;Add,Cast,Clip,... for a single opset
         `ai.onnx;11,12;Add,Cast,Clip,... for multiple opsets

    note: Configuration information is accrued as the file is parsed. If an operator requires support from multiple
    opsets that can be done with one entry for each opset, or one entry with multiple opsets in it.

    If the configuration file is generated from ORT format models it may optionally contain JSON for per-operator
    type reduction. The required types are generally listed per input and/or output of the operator.
    The type information is in a map, with 'inputs' and 'outputs' keys. The value for 'inputs' or 'outputs' is a map
    between the index number of the input/output and the required list of types.

    For example, both the input and output types are relevant to ai.onnx:Cast.
    Type information for input 0 and output 0 could look like this:
        `{"inputs": {"0": ["float", "int32_t"]}, "outputs": {"0": ["float", "int64_t"]}}`

    which is added directly after the operator name in the configuration file.
    e.g.
        `ai.onnx;12;Add,Cast{"inputs": {"0": ["float", "int32_t"]}, "outputs": {"0": ["float", "int64_t"]}},Concat`

    If for example the types of inputs 0 and 1 were important, the entry may look like this (e.g. ai.onnx:Gather):
        `{"inputs": {"0": ["float", "int32_t"], "1": ["int32_t"]}}`

    Finally some operators do non-standard things and store their type information under a 'custom' key.
    ai.onnx.OneHot is an example of this, where the three input types are combined into a triple.
        `{"custom": [["float", "int64_t", "int64_t"], ["int64_t", "std::string", "int64_t"]]}`

    2. Specifying globally allowed types for all operators

    The format for specifying globally allowed types for all operators is:
        `!globally_allowed_types;T0,T1,...`

    Ti should be a C++ scalar type supported by ONNX and ORT.
    At most one globally allowed types specification is allowed.

    Specifying per-operator type information and specifying globally allowed types are mutually exclusive - it is an
    error to specify both.

    3. Specify what it means for no required operators to be specified

    By default, if no required operators are specified, NO operators are required.

    With the following line, if no required operators are specified, ALL operators are required:
        `!no_ops_specified_means_all_ops_are_required`

    :param config_file: Configuration file to parse
    :param enable_type_reduction: Set to True to use the type information in the config.
                                  If False the type information will be ignored.
                                  If the flatbuffers module is unavailable type information will be ignored as the
                                  type-based filtering has a dependency on the ORT flatbuffers schema.
    :return: required_ops: Dictionary of domain:opset:[ops] for required operators. If None, all operators are
                           required.
             op_type_impl_filter: OpTypeImplFilterInterface instance if type reduction is enabled, the flatbuffers
                                  module is available, and type reduction information is present. None otherwise.
    """

    if not os.path.isfile(config_file):
        raise ValueError(f"Configuration file {config_file} does not exist")

    # only enable type reduction when flatbuffers is available
    enable_type_reduction = enable_type_reduction and have_flatbuffers

    required_ops = {}
    no_ops_specified_means_all_ops_are_required = False
    op_type_usage_manager = OperatorTypeUsageManager() if enable_type_reduction else None
    has_op_type_reduction_info = False
    globally_allowed_types = None

    def process_non_op_line(line: str) -> bool:
        """Handle comment/blank/directive lines. Returns True if the line was consumed."""
        if not line or line.startswith("#"):  # skip empty lines and comments
            return True

        if line.startswith("!globally_allowed_types;"):  # handle globally allowed types
            if enable_type_reduction:
                nonlocal globally_allowed_types
                if globally_allowed_types is not None:
                    raise RuntimeError("Globally allowed types were already specified.")
                globally_allowed_types = {segment.strip() for segment in line.split(";")[1].split(",")}
            return True

        if line == "!no_ops_specified_means_all_ops_are_required":  # handle all ops required line
            nonlocal no_ops_specified_means_all_ops_are_required
            no_ops_specified_means_all_ops_are_required = True
            return True

        return False

    with open(config_file) as config:
        for line in [orig_line.strip() for orig_line in config]:
            if process_non_op_line(line):
                continue

            domain, opset_str, operators_str = (segment.strip() for segment in line.split(";"))
            opsets = [int(s) for s in opset_str.split(",")]

            # any type reduction information is serialized json that starts/ends with { and }.
            # type info is optional for each operator.
            if "{" in operators_str:
                has_op_type_reduction_info = True

                # parse the entries in the json dictionary with type info
                operators = set()
                cur = 0
                end = len(operators_str)
                while cur < end:
                    next_comma = operators_str.find(",", cur)
                    next_open_brace = operators_str.find("{", cur)

                    if next_comma == -1:
                        next_comma = end

                    # the json string starts with '{', so if that is found (next_open_brace != -1)
                    # before the next comma (which would be the start of the next operator if there is no type info
                    # for the current operator), we have type info to parse.
                    # e.g. need to handle extracting the operator name and type info for OpB and OpD,
                    # and just the operator names for OpA and OpC from this example string
                    # OpA,OpB{"inputs": {"0": ["float", "int32_t"]}},OpC,OpD{"outputs": {"0": ["int32_t"]}}
                    if 0 < next_open_brace < next_comma:
                        operator = operators_str[cur:next_open_brace].strip()
                        operators.add(operator)

                        # parse out the json dictionary with the type info by finding the closing brace that matches
                        # the opening brace
                        i = next_open_brace + 1
                        num_open_braces = 1
                        while num_open_braces > 0 and i < end:
                            if operators_str[i] == "{":
                                num_open_braces += 1
                            elif operators_str[i] == "}":
                                num_open_braces -= 1
                            i += 1

                        if num_open_braces != 0:
                            raise RuntimeError("Mismatched { and } in type string: " + operators_str[next_open_brace:])

                        if op_type_usage_manager:
                            type_str = operators_str[next_open_brace:i]
                            op_type_usage_manager.restore_from_config_entry(domain, operator, type_str)

                        cur = i + 1
                    else:
                        # comma or end of line is next.
                        # next_comma was already clamped to `end` above, so it always marks
                        # the end of the current operator name.
                        operators.add(operators_str[cur:next_comma].strip())
                        cur = next_comma + 1

            else:
                operators = {op.strip() for op in operators_str.split(",")}

            for opset in opsets:
                # Store an independent copy of the operator set per opset. Previously the
                # same set object was shared by every opset on the line, so a later config
                # entry that updated one opset's set silently added operators to the
                # sibling opsets as well.
                ops_for_opset = set(operators)
                if domain not in required_ops:
                    required_ops[domain] = {opset: ops_for_opset}
                elif opset not in required_ops[domain]:
                    required_ops[domain][opset] = ops_for_opset
                else:
                    required_ops[domain][opset].update(ops_for_opset)

    # empty required_ops + the directive means "everything is required", signalled by None
    if len(required_ops) == 0 and no_ops_specified_means_all_ops_are_required:
        required_ops = None

    op_type_impl_filter = None
    if enable_type_reduction:
        if not has_op_type_reduction_info:
            # no per-op type info was seen, so the manager accrued nothing useful
            op_type_usage_manager = None
        if globally_allowed_types is not None and op_type_usage_manager is not None:
            raise RuntimeError(
                "Specifying globally allowed types and per-op type reduction info together is unsupported."
            )

        if globally_allowed_types is not None:
            op_type_impl_filter = GloballyAllowedTypesOpTypeImplFilter(globally_allowed_types)
        elif op_type_usage_manager is not None:
            op_type_impl_filter = op_type_usage_manager.make_op_type_impl_filter()

    return required_ops, op_type_impl_filter