onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,137 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from collections import defaultdict
6
+ from logging import getLogger
7
+ from typing import Any, Dict, List, Optional, Sequence, Union
8
+
9
+ import numpy as np
10
+ from onnx import NodeProto, helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
+ class Fusion:
17
+ """
18
+ Base class for Graph Fusion
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ model: OnnxModel,
24
+ fused_op_type: str,
25
+ search_op_types: Union[str, List[str]],
26
+ description: str = "",
27
+ ):
28
+ self.search_op_types: List[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types
29
+ self.fused_op_type: str = fused_op_type
30
+ self.description: str = f"{fused_op_type}({description})" if description else fused_op_type
31
+ self.model: OnnxModel = model
32
+ self.nodes_to_remove: List = []
33
+ self.nodes_to_add: List = []
34
+ self.prune_graph: bool = False
35
+ self.node_name_to_graph_name: dict = {}
36
+ self.this_graph_name: Optional[str] = None
37
+ # It is optional that subclass updates fused_count since we will also check nodes_to_add to get counter.
38
+ self.fused_count: defaultdict = defaultdict(int)
39
+
40
+ def increase_counter(self, fused_op_name: str):
41
+ """
42
+ Increase counter of a fused operator.
43
+ """
44
+ self.fused_count[fused_op_name] += 1
45
+
46
+ def fuse(
47
+ self,
48
+ node: NodeProto,
49
+ input_name_to_nodes: Dict[str, List[NodeProto]],
50
+ output_name_to_node: Dict[str, NodeProto],
51
+ ):
52
+ """Interface for fusion that starts from a node"""
53
+ raise NotImplementedError
54
+
55
+ def apply(self):
56
+ """
57
+ Apply graph fusion on the whole model graph.
58
+ It searched nodes of given operators, and start fusion on each of those nodes.
59
+ """
60
+ logger.debug(f"start {self.description} fusion...")
61
+ input_name_to_nodes = self.model.input_name_to_nodes()
62
+ output_name_to_node = self.model.output_name_to_node()
63
+
64
+ # This assumes that two search ops will not be fused at same time!
65
+ for search_op_type in self.search_op_types:
66
+ for node in self.model.get_nodes_by_op_type(search_op_type):
67
+ graph = self.model.get_graph_by_node(node)
68
+ if graph is None:
69
+ raise Exception("Can not find node in any graph")
70
+ self.this_graph_name = graph.name
71
+ self.fuse(node, input_name_to_nodes, output_name_to_node)
72
+
73
+ op_list = [node.op_type for node in self.nodes_to_add]
74
+ if self.fused_count:
75
+ for key, value in self.fused_count.items():
76
+ if value:
77
+ logger.info(f"Fused {key}: {value}")
78
+ else:
79
+ count = op_list.count(self.fused_op_type)
80
+ if count > 0:
81
+ logger.info(f"Fused {self.description}: {count}")
82
+
83
+ self.model.remove_nodes(self.nodes_to_remove)
84
+ self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
85
+
86
+ if self.prune_graph:
87
+ self.model.prune_graph()
88
+ elif self.nodes_to_remove or self.nodes_to_add:
89
+ self.model.update_graph()
90
+
91
+ def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
92
+ if raw:
93
+ np_type = helper.tensor_dtype_to_np_dtype(data_type)
94
+ if not isinstance(vals, np.ndarray):
95
+ bytes = np.array(vals, dtype=np_type).tobytes()
96
+ else:
97
+ bytes = vals.astype(np_type).tobytes()
98
+ tensor = helper.make_tensor(
99
+ name=name,
100
+ data_type=data_type,
101
+ dims=dims,
102
+ vals=bytes,
103
+ raw=True,
104
+ )
105
+ else:
106
+ tensor = helper.make_tensor(
107
+ name=name,
108
+ data_type=data_type,
109
+ dims=dims,
110
+ vals=vals,
111
+ raw=False,
112
+ )
113
+
114
+ self.model.add_initializer(tensor, self.this_graph_name)
115
+ return tensor
116
+
117
+ def add_nodes_to_remove(self, nodes: List[NodeProto]):
118
+ # Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths).
119
+ # When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B
120
+ # is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are
121
+ # iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first.
122
+ # Since path A's shared nodes are removed, path B's shared nodes are not removed because they
123
+ # were previously removed for path A. This causes an error to print in remove_node that a node
124
+ # has failed to be removed.
125
+ #
126
+ # To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`.
127
+ # We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could
128
+ # be scenarios where the nodes need to be removed in a specific order and converting to a set would
129
+ # lose this order.
130
+ for node in nodes:
131
+ if node not in self.nodes_to_remove:
132
+ self.nodes_to_remove.append(node)
133
+
134
+ def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]):
135
+ for node in nodes:
136
+ if node not in self.nodes_to_remove and node not in nodes_to_keep:
137
+ self.nodes_to_remove.append(node)
@@ -0,0 +1,58 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+ from typing import Dict
7
+
8
+ from fusion_base import Fusion
9
+ from numpy import ndarray
10
+ from onnx import helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
class FusionBiasAdd(Fusion):
    """Fuses a bias Add followed by a skip-connection Add into one BiasAdd node."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "BiasAdd", "Add")

    def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Add bias and Add skip connection into BiasAdd
        """
        # The skip Add's first input must be produced by this exact chain.
        matched_path = self.model.match_parent_path(
            add_node,
            ["Add", "MatMul", "BiasSplitGelu", "MatMul", "SkipLayerNormalization"],
            [0, None, 0, 0, 0],
            output_name_to_node,
        )
        if matched_path is None:
            return

        bias_add = matched_path[0]
        skip_layernorm = matched_path[-1]

        # The skip connection (second input of the Add) must come straight from SkipLayerNormalization.
        if add_node.input[1] not in skip_layernorm.output:
            return

        bias_index, bias_value = self.model.get_constant_input(bias_add)
        is_valid_bias = (
            isinstance(bias_index, int)
            and bias_value is not None
            and isinstance(bias_value, ndarray)
            and bias_value.ndim == 1  # bias must be one-dimensional
        )
        if not is_valid_bias:
            return

        self.nodes_to_remove.extend([add_node, bias_add])

        fused_name = self.model.create_node_name("BiasAdd")
        fused_node = helper.make_node(
            "BiasAdd",
            inputs=[bias_add.input[1 - bias_index], bias_add.input[bias_index], add_node.input[1]],
            outputs=[add_node.output[0]],
            name=fused_name,
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_name] = self.this_graph_name
@@ -0,0 +1,66 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+
8
+ from fusion_base import Fusion
9
+ from fusion_utils import NumpyHelper
10
+ from onnx import helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
class FusionBiasGelu(Fusion):
    """Fuses MatMul -> Add(1-D bias) -> Gelu/FastGelu into BiasGelu (or FastGelu with bias)."""

    def __init__(self, model: OnnxModel, is_fastgelu):
        if is_fastgelu:
            super().__init__(model, "FastGelu", "FastGelu", "add bias")
        else:
            super().__init__(model, "BiasGelu", "Gelu")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Fold the bias Add preceding a Gelu/FastGelu node into the fused op.

        The Add must sit between a MatMul and the activation, and its constant
        input (the bias) must be one-dimensional.
        """
        gelu_op_type = node.op_type
        fuse_op_type = "BiasGelu" if gelu_op_type == "Gelu" else "FastGelu"

        # Only fuse when the activation has no bias input yet.
        if len(node.input) != 1:
            return

        nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None])
        if nodes is None:
            return
        (add, matmul) = nodes

        # Find the constant (initializer) input of the Add node: that is the bias.
        bias_weight = None
        bias_index = -1
        # Renamed loop variable from `input` to avoid shadowing the builtin.
        for i, input_name in enumerate(add.input):
            initializer = self.model.get_initializer(input_name)
            if initializer is None:
                continue
            bias_index = i
            bias_weight = NumpyHelper.to_array(initializer)
            break
        if bias_weight is None:
            return
        # bias should be one dimension
        if bias_weight.ndim != 1:
            return

        subgraph_nodes = [node, add]
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
        ):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        fused_node = helper.make_node(
            fuse_op_type,
            inputs=[matmul.output[0], add.input[bias_index]],
            outputs=node.output,
            name=self.model.create_node_name(fuse_op_type, gelu_op_type + "_AddBias_"),
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
@@ -0,0 +1,111 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+ from typing import Dict
7
+
8
+ from fusion_base import Fusion
9
+ from onnx import helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = getLogger(__name__)
13
+
14
+
15
class FusionBiasSplitGelu(Fusion):
    """Fuse a bias-add + split-in-half + Gelu + multiply subgraph into one
    ``com.microsoft`` BiasSplitGelu node.

    The matched subgraph adds a bias, slices the result into two halves along
    the last axis, applies Gelu to the second half, and multiplies it with the
    first half (see the diagram in :meth:`fuse`).
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "BiasSplitGelu", "Gelu")

    def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        [root] --->Add --------------------> Slice ---------------> Mul -->
                    |                          ^                     ^
                    |                          |                     |
                    +----------------------------+---Slice --> Gelu---+
                    |                          |    ^
                    |                          |-----|
                    |                          |     |
                    |                         Mul   Mul
                    |                          ^     ^
                    v                          |     |
                   Shape ---> Gather --> Add --> Div --+
        """
        # The Gelu output must feed exactly one node, and it must be a Mul.
        if gelu_node.output[0] not in input_name_to_nodes:
            return
        children = input_name_to_nodes[gelu_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return
        mul_after_gelu = children[0]

        # The Gelu input must come from a Slice (the second half of the split).
        slice_before_gelu = self.model.match_parent(gelu_node, "Slice", 0, output_name_to_node)
        if slice_before_gelu is None:
            return

        # The Slice must have a constant input equal to -1 at input index 3
        # (presumably the `axes` input, i.e. slicing the last axis) — verify
        # against OnnxModel.find_constant_input semantics.
        if self.model.find_constant_input(slice_before_gelu, -1, delta=0.001) != 3:
            return

        # Output of the bias Add; both Slice nodes and the Shape must read it.
        add_output = slice_before_gelu.input[0]

        # Match the chain that computes the split point (half the hidden size):
        # Shape -> Gather -> Add -> Div, read from the bias Add output.
        start_index_nodes = self.model.match_parent_path(
            slice_before_gelu,
            ["Div", "Add", "Gather", "Shape", "Add"],
            [1, 0, 0, 0, 0],
            output_name_to_node,  # Mul(1) is optional
        )
        if start_index_nodes is None:
            # Retry with the optional Mul in front of the Div.
            start_index_nodes = self.model.match_parent_path(
                slice_before_gelu,
                ["Mul", "Div", "Add", "Gather", "Shape", "Add"],
                [1, 0, 0, 0, 0, 0],
                output_name_to_node,
            )

        # The Shape node (second-to-last in the path) must consume the same
        # bias Add output that the Slice consumes.
        if start_index_nodes is None or start_index_nodes[-2].input[0] != add_output:
            return

        end_index_nodes = self.model.match_parent_path(slice_before_gelu, ["Mul", "Div"], [2, 0], output_name_to_node)

        if (
            end_index_nodes is None or end_index_nodes[1] not in start_index_nodes
        ):  # the Div is parent of both two Mul nodes
            return

        # The first half of the split: the Slice feeding the final Mul.
        slice_before_mul = self.model.match_parent(mul_after_gelu, "Slice", 0, output_name_to_node)
        if slice_before_mul is None:
            return

        if (
            slice_before_mul.input[2] != slice_before_gelu.input[1]
        ):  # end index of slice_before_mul is start index of slice_before_gelu
            return

        subgraph_nodes = [
            *start_index_nodes,
            end_index_nodes[0],
            mul_after_gelu,
            gelu_node,
            slice_before_mul,
            slice_before_gelu,
        ]
        subgraph_output = mul_after_gelu.output[0]
        # All matched nodes must be removable without breaking other consumers.
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
        ):
            logger.info("Skip fuse BiasSplitGelu since it is not safe to fuse the subgraph.")
            return

        # The bias Add is the last node of the matched start-index path.
        # get_constant_input returns (input index, value) of its constant input.
        add_node = start_index_nodes[-1]
        bias_index, _value = self.model.get_constant_input(add_node)
        if not isinstance(bias_index, int):
            return
        self.nodes_to_remove.extend(subgraph_nodes)
        node_name = self.model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu")
        # Fused node inputs: [data, bias] taken from the original bias Add.
        fused_node = helper.make_node(
            "BiasSplitGelu",
            inputs=[add_node.input[1 - bias_index], add_node.input[bias_index]],
            outputs=[subgraph_output],
            name=node_name,
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[node_name] = self.this_graph_name
@@ -0,0 +1,143 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import logging
6
+
7
+ from fusion_attention import AttentionMask, FusionAttention
8
+ from onnx_model import OnnxModel
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class FusionConformerAttention(FusionAttention):
    """
    Fuse Conformer Attention subgraph into one MultiHeadAttention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        """
        Args:
            model: the OnnxModel being optimized.
            hidden_size: hidden dimension of the attention layer.
            num_heads: number of attention heads.
            attention_mask: mask helper passed through to FusionAttention.
        """
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Match the Q/K/V, QK and output-projection paths rooted at
        ``normalize_node`` and, if all match, replace them with one
        MultiHeadAttention node (with past/present key-value cache inputs).
        """
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        # Output projection path: Add(bias) <- MatMul <- Reshape <- Transpose <- MatMul(QK·V).
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (
                _,
                _,
                reshape_qkv,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes
        else:
            logger.debug("fuse_conformer_attention: failed to match qkv path")
            return

        # Value path; the Concat merges past value state with the new value.
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, 1],
        )

        add_v = None
        if v_nodes is not None:
            (concat_v, _, _, add_v, matmul_v) = v_nodes
            # Concat input 0 supplies the past value; its output is the present value.
            # NOTE(review): get_parent may return None here, which would raise
            # AttributeError on the next line — confirm the pattern guarantees a parent.
            concat_parent = self.model.get_parent(concat_v, 0, None)
            present_v = concat_v.output[0]
            past_v = concat_parent.output[0]
        else:
            logger.debug("fuse_conformer_attention: failed to match v path")
            return

        # Attention-score path: Softmax(Add(QK, attention bias/mask)).
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])

        if qk_nodes is not None:
            _, add_qk, matmul_qk = qk_nodes
        else:
            logger.debug("fuse_conformer_attention: failed to match qk path")
            return

        # Query path: Div is the 1/sqrt(head_size) scaling of Q.
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Div", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        if q_nodes is not None:
            _, _, reshape_q, add_q, matmul_q = q_nodes
        else:
            logger.debug("fuse_conformer_attention: failed to match q path")
            return

        # Key path; like V, the Concat merges past key state with the new key.
        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, 1],
        )

        matmul_k = None
        if k_nodes is not None:
            _, concat_k, _, _, add_k, matmul_k = k_nodes
            # NOTE(review): same potential None from get_parent as in the v path.
            concat_parent = self.model.get_parent(concat_k, 0, None)
            past_k = concat_parent.output[0]
            present_k = concat_k.output[0]
        else:
            logger.debug("fuse_conformer_attention: failed to match k path")
            return

        attention_last_node = reshape_qkv
        # Derive head count / hidden size from the Q reshape; both must be
        # positive and evenly divisible.
        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
            logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
            return

        new_node = self.create_multihead_attention_node(
            matmul_q,
            matmul_k,
            matmul_v,
            add_q,
            add_k,
            add_v,
            num_heads,
            hidden_size,
            attention_last_node.output[0],
            add_qk=add_qk.input[1],
            past_k=past_k,
            past_v=past_v,
            present_k=present_k,
            present_v=present_v,
        )

        if new_node is None:
            logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)

        # When using multihead attention, keep MatMul nodes in original graph
        if q_nodes[-1].op_type == "MatMul":
            q_nodes.pop()
        if k_nodes[-1].op_type == "MatMul":
            k_nodes.pop()
        if v_nodes[-1].op_type == "MatMul":
            v_nodes.pop()

        self.nodes_to_remove.extend(k_nodes)
        self.nodes_to_remove.extend(v_nodes)

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True