onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,110 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+ from typing import Dict, List, Union
8
+
9
+ from fusion_base import Fusion
10
+ from fusion_utils import FusionUtils
11
+ from numpy import ndarray
12
+ from onnx import NodeProto, TensorProto
13
+ from onnx_model import OnnxModel
14
+
15
+ logger = getLogger(__name__)
16
+
17
+
18
class FusionShape(Fusion):
    """Fuse a per-dimension Shape/Gather/Unsqueeze/Concat subgraph into a single Shape node.

    The pattern rebuilds a tensor's full shape by extracting every dimension
    separately and concatenating them; when all pieces come from the same root
    tensor and together cover all of its dimensions, the Concat output is
    equivalent to the root's Shape output itself.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "Shape", "Concat")
        self.utils = FusionUtils(model)
        # Symbolic shape inference is expensive, so it is created lazily on
        # first use in get_dimensions() and cached afterwards.
        self.shape_infer = None
        self.shape_infer_done = False

    def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
        """Return the rank recorded in a value-info proto, or None when it has no shape."""
        if tensor_proto.type.tensor_type.HasField("shape"):
            return len(tensor_proto.type.tensor_type.shape.dim)
        else:
            return None

    def get_dimensions(self, input_name: str) -> Union[int, None]:
        """Return the rank of the named graph edge, or None when it cannot be determined.

        Tries the model's static shape information first, then falls back to
        (lazily created) symbolic shape inference.
        """
        shape = self.model.get_shape(input_name)
        if shape is not None:
            return len(shape)

        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape(update=True)
            self.shape_infer_done = True

        if self.shape_infer is not None:
            # known_vi_ maps an edge name to its inferred value info.
            return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])

        return None

    def fuse(
        self,
        concat_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """Replace the Concat of per-dimension shape extractions by the root Shape output.

        Returns early (making no change) as soon as any Concat input fails to
        match the expected pattern.
        """
        # Simplify subgraph like
        #
        #                 (2d_input)
        #                  /      \
        #              Shape      Shape
        #                |          |
        #   Gather(indices=0)   Gather(indices=1)
        #                |          |
        #   Unsqueeze(axes=0)   Unsqueeze(axes=0)
        #                 \        /
        #                   Concat
        #                     |
        #
        # into (2d_input) --> Shape -->
        opset_version = self.model.get_opset_version()

        inputs = len(concat_node.input)
        root = None
        shape_output = None
        for i in range(inputs):
            # Each Concat input must be produced by Unsqueeze <- Gather <- Shape.
            path = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Gather", "Shape"],
                [i, 0, 0],
                output_name_to_node,
            )
            if path is None:
                return

            unsqueeze, gather, shape = path
            if i == 0:
                shape_output = shape.output[0]
            if root is None:
                root = shape.input[0]
                # All dimensions of the root must be covered; otherwise the
                # Concat output would only be a partial shape.
                if self.get_dimensions(root) != inputs:
                    return
            elif shape.input[0] != root:
                # Every Shape node must read from the same root tensor.
                return

            if not FusionUtils.check_node_attribute(unsqueeze, "axis", 0, default_value=0):
                return

            if opset_version < 13:
                # Before opset 13, Unsqueeze carries axes as an attribute.
                if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
                    return
            else:
                # From opset 13 on, axes is the second input of Unsqueeze.
                if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                    return

            value = self.model.get_constant_value(gather.input[1])

            # The i-th Concat input must select exactly dimension i.
            if not (isinstance(value, ndarray) and value.size == 1 and value.item() == i):
                return

        # Only rewire when the Concat output is not a graph output (graph
        # outputs must keep their producing node).
        if self.model.find_graph_output(concat_node.output[0]) is None:
            self.model.replace_input_of_all_nodes(concat_node.output[0], shape_output)
            # NOTE(review): the counter key "Reshape" looks inconsistent with the
            # fused op type "Shape" passed to __init__ — confirm this is intended.
            self.increase_counter("Reshape")
            self.prune_graph = True
@@ -0,0 +1,159 @@
1
+ import logging
2
+ from typing import Dict
3
+
4
+ from fusion_base import Fusion
5
+ from fusion_skiplayernorm import FusionSkipLayerNormalization
6
+ from onnx import helper
7
+ from onnx_model import OnnxModel
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class FusionSimplifiedLayerNormalization(Fusion):
    """Fuse an expanded SimplifiedLayerNormalization (RMSNorm-style) subgraph into one node.

    The match anchors on the final scale Mul and recognizes several pattern
    variants that differ in which node feeds the normalization (Add, Gather,
    Mul, or directly a graph input).
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SimplifiedLayerNormalization", "Mul")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """Replace a matched subgraph ending at *node* with a SimplifiedLayerNormalization node.

        Returns early (making no change) when no known pattern matches, or when
        the matched constants (Pow exponent, epsilon) are not the expected values.
        """
        if node.op_type != "Mul":
            return

        sim_ln_nodes = None
        # SimplifiedLayerNorm calculation (notation from https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary):
        # DD = Pow(D, 2)
        # Var = ReduceMean(DD)
        # VarEps = Add(Var, epsilon)
        # StdDev = Sqrt(VarEps)
        # InvStdDev = Div(1, StdDev)
        # Normalized = Mul(D, InvStdDev)
        # NormalizedScaled = Mul(Normalized, Scale)

        #                              SimplifiedLayerNorm
        #          +-------------------------------------------------------+
        #          |                                                       |
        # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul
        #                                                                  |
        #                                                                 node
        sim_ln_nodes_1 = self.model.match_parent_path(
            node,
            ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"],
            [1, 1, 1, 0, 0, 0, 0],
        )
        #                                 SimplifiedLayerNorm
        #             +-------------------------------------------------------+
        #             |                                                       |
        # Gather --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul
        #                                                                     |
        #                                                                    node
        sim_ln_nodes_2 = self.model.match_parent_path(
            node,
            ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"],
            [1, 1, 1, 0, 0, 0, 0],
        )

        # For LLaMA from Microsoft custom export:
        # sim_ln_nodes_3 uses a different start parent index than sim_ln_nodes_1
        #
        #                              SimplifiedLayerNorm
        #          +-------------------------------------------------------+
        #          |                                                       |
        # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul
        #                                                                  |
        #                                                                 node
        sim_ln_nodes_3 = self.model.match_parent_path(
            node,
            ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"],
            [0, 1, 1, 0, 0, 0, 0],
        )

        # sim_ln_nodes_4 starts with a graph input instead of an Add node like sim_ln_nodes_3
        #
        #                         SimplifiedLayerNorm
        #               +-----------------------------------------------+
        #               |                                               |
        # graph_input --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul
        #                                                                 |
        #                                                                node
        sim_ln_nodes_4 = self.model.match_parent_path(
            node,
            ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow"],
            [0, 1, 1, 0, 0, 0],
        )

        # For Gemma from Microsoft custom export, which has a Multiply after the Gather:
        #
        #                              SimplifiedLayerNorm
        #          +-------------------------------------------------------+
        #          |                                                       |
        # Mul --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul
        #                                                                  |
        #                                                                 node
        sim_ln_nodes_5 = self.model.match_parent_path(
            node,
            ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Mul"],
            [1, 1, 1, 0, 0, 0, 0],
        )

        # Pick the first matching variant. add_node carries epsilon; pow_node
        # squares the normalized input.
        add_node, pow_node = None, None
        if sim_ln_nodes_1 is not None:
            sim_ln_nodes = sim_ln_nodes_1
            add_node = sim_ln_nodes[3]
            pow_node = sim_ln_nodes[-2]
        elif sim_ln_nodes_2 is not None:
            sim_ln_nodes = sim_ln_nodes_2
            add_node = sim_ln_nodes[3]
            pow_node = sim_ln_nodes[-2]
        elif sim_ln_nodes_3 is not None:
            sim_ln_nodes = sim_ln_nodes_3
            add_node = sim_ln_nodes[3]
            pow_node = sim_ln_nodes[-2]
        elif sim_ln_nodes_4 is not None:
            sim_ln_nodes = sim_ln_nodes_4
            add_node = sim_ln_nodes[3]
            # Pattern 4 has no producer node before Pow, so Pow is the last match.
            pow_node = sim_ln_nodes[-1]
            # Verify that parent input to Pow node is graph_input
            if pow_node.input[0] not in self.model.get_graphs_input_names():
                return
        elif sim_ln_nodes_5 is not None:
            sim_ln_nodes = sim_ln_nodes_5
            add_node = sim_ln_nodes[3]
            pow_node = sim_ln_nodes[-2]
        else:
            return

        # For patterns 3 and 4 the scale tensor is input 1 of the final Mul;
        # otherwise it is input 0.
        layernorm_weight_index = 1 if sim_ln_nodes in (sim_ln_nodes_3, sim_ln_nodes_4) else 0
        starts_with_graph_input = sim_ln_nodes == sim_ln_nodes_4

        # The Pow must be an exact square (exponent constant == 2.0 at input 1).
        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        # The squared tensor must also feed the inner Mul (the re-scaling by
        # the inverse std-dev) — otherwise this is not a normalization.
        root_input = pow_node.input[0]
        if root_input != sim_ln_nodes[0].input[0]:
            return

        # i is the index of the constant input on the Add node (unused here).
        i, add_weight = self.model.get_constant_input(add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            logger.warning(f"epsilon value is not expected: {add_weight}")
            return

        # Remove matched nodes. The last matched node (producer of root_input)
        # is kept unless the pattern starts directly at a graph input.
        self.nodes_to_remove.extend(sim_ln_nodes[:-1] if not starts_with_graph_input else sim_ln_nodes)
        self.nodes_to_remove.append(node)

        normalize_node = helper.make_node(
            "SimplifiedLayerNormalization",
            inputs=[root_input, node.input[layernorm_weight_index]],
            outputs=[node.output[0]],
            name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"),
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])
        normalize_node.attribute.extend([helper.make_attribute("axis", -1)])
        normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)])
        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
152
+
153
+
154
class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
    """Fuse Add + SimplifiedLayerNormalization into SkipSimplifiedLayerNormalization.

    All matching and rewriting logic is inherited from
    FusionSkipLayerNormalization; this subclass only changes the fused op type
    and the op type searched for. The original pass-through ``fuse`` override,
    which merely delegated to ``super().fuse(...)``, was redundant and has been
    removed — the inherited method is invoked identically without it.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization")
@@ -0,0 +1,255 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+ from typing import List
7
+
8
+ from fusion_base import Fusion
9
+ from fusion_utils import NumpyHelper
10
+ from onnx import helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
class FusionSkipGroupNorm(Fusion):
    """
    Fuse Add + GroupNorm into one node: SkipGroupNorm (com.microsoft domain).

    GroupNorm here consumes channel-last (NHWC) input produced by a
    Transpose(perm=[0,2,3,1]) after an Add in NCHW layout, so the fusion also
    rewires or inserts Transpose nodes around the fused node.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SkipGroupNorm", "GroupNorm")
        # Update shape inference is needed since other fusions might add new edge which does not have shape info yet.
        self.shape_infer_helper = self.model.infer_runtime_shape(update=True)

        if self.shape_infer_helper is None:
            logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.")

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
        """Append a Transpose node after an input.

        Args:
            input_name: name of the edge to transpose.
            perm: permutation for the Transpose node's "perm" attribute.
            output_name: optional explicit output edge name; generated from the
                node name and input name when None.

        Returns:
            The new Transpose NodeProto (caller is responsible for adding it to the graph).
        """
        node_name = self.model.create_node_name("Transpose")
        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name
        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
        return transpose_node

    def get_skip_index(self, add, is_channel_last: bool):
        """Add has two inputs. This classifies which input is skip based on shape info (skip allows broadcast).

        Args:
            add: the Add node whose two 4-D inputs are classified.
            is_channel_last: whether the shapes are NHWC (True) or NCHW (False).

        Returns:
            Tuple (skip, broadcast) where skip is the index (0 or 1) of the skip
            input, or -1 when the shapes do not match a supported pattern, and
            broadcast indicates the skip input is broadcast over the spatial dims.
        """
        skip = -1
        broadcast = False

        assert self.shape_infer_helper is not None
        shape_a = self.shape_infer_helper.get_edge_shape(add.input[0])
        shape_b = self.shape_infer_helper.get_edge_shape(add.input[1])
        assert shape_a is not None and shape_b is not None

        if len(shape_a) == 4 and len(shape_b) == 4:
            if shape_a == shape_b:
                skip = 1
            else:
                # Allow broadcast over height/width when batch and channel match.
                c = 3 if is_channel_last else 1
                h = 1 if is_channel_last else 2
                w = 2 if is_channel_last else 3
                if shape_a[0] == shape_b[0] and shape_a[c] == shape_b[c]:
                    if shape_b[h] == 1 and shape_b[w] == 1:
                        skip = 1
                        broadcast = True
                    elif shape_a[h] == 1 and shape_a[w] == 1:
                        skip = 0
                        broadcast = True

        if skip < 0:
            logger.debug(
                "skip SkipGroupNorm fusion since shape of Add inputs (%s, %s) are not expected",
                add.input[0],
                add.input[1],
            )
        return skip, broadcast

    def has_multiple_consumers(self, output_name, input_name_to_nodes):
        """Whether an output has multiple consumers (like graph output or more than one children nodes)"""
        return self.model.find_graph_output(output_name) is not None or (
            output_name in input_name_to_nodes and len(input_name_to_nodes[output_name]) > 1
        )

    def remove_if_safe(self, node, input_name_to_nodes):
        """Remove a node if it is safe (only one children, and not graph output)"""
        if not self.has_multiple_consumers(node.output[0], input_name_to_nodes):
            self.nodes_to_remove.extend([node])

    def is_bias_1d(self, bias_name: str):
        """Whether bias is an initializer of one dimension"""
        initializer = self.model.get_initializer(bias_name)
        if initializer is None:
            return False

        bias_weight = NumpyHelper.to_array(initializer)
        if bias_weight is None:
            logger.debug("Bias weight not found")
            return False

        if len(bias_weight.shape) != 1:
            logger.debug("Bias weight is not 1D")
            return False
        return True

    def match_bias_path(self, node, input_name_to_nodes, output_name_to_node):
        """
        Match the bias graph pattern from an Transpose node after Reshape node like in below example.
        It checks whether the bias is 1D initializer. If so, remove Add and redirect MatMul output to Reshape.

        Returns:
            The bias initializer name when the pattern matched, otherwise None.
        """
        # Before Fusion:
        #     MatMul  (bias)
        #        \   /       (shape)
        #         Add       /
        #          \       /
        #    (a)    Reshape
        #      \    |
        #       \   Transpose([0, 3, 1, 2])    Transpose([0, 3, 1, 2]) --- the start node, this func only handles the above nodes.
        #        \  /
        #         Add
        #         / \
        #      (c)   Transpose([0,2,3,1])
        #                 |
        #             GroupNorm
        #                 |
        #                (d)
        #
        # After Fusion (the nodes below Reshape is handled in the fuse function):
        #     MatMul  (shape)
        #        \    /
        #  (a)   Reshape
        #    \    /
        #  SkipGroupNorm
        #      /   \
        #   (d)  Transpose([0, 3, 1, 2])
        #            \
        #            (c)

        add_input_index = []
        bias_nodes = self.model.match_parent_path(
            node, ["Reshape", "Add", "MatMul"], [0, 0, None], output_name_to_node, add_input_index
        )
        if bias_nodes is None:
            return None

        (reshape, add_bias, matmul) = bias_nodes
        # The bias is the Add input that is NOT fed by the MatMul.
        bias = bias_nodes[1].input[1 - add_input_index[0]]
        if not self.is_bias_1d(bias):
            return None

        # Bypass the bias Add: the bias will be carried by SkipGroupNorm instead.
        reshape.input[0] = matmul.output[0]
        self.remove_if_safe(add_bias, input_name_to_nodes)

        return bias

    def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node):
        """Match whether an output is from a Transpose(perm=[0,3,1,2]) node."""
        parent = output_name_to_node.get(output_name, None)
        if parent is not None and parent.op_type == "Transpose":
            permutation = OnnxModel.get_node_attribute(parent, "perm")
            if permutation == [0, 3, 1, 2]:
                self.remove_if_safe(parent, input_name_to_nodes)
                return parent
        return None

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse Add + Transpose + GroupNorm (rooted at `node`) into SkipGroupNorm."""
        # This fusion requires shape information, so skip it if shape is not available.
        if self.shape_infer_helper is None:
            return

        # Before Fusion:
        #    (a)  (b)
        #      \  /
        #       Add
        #        /\
        #     (c)  Transpose([0,2,3,1])
        #            \
        #           GroupNorm
        #              |
        #             (d)
        #
        # After Fusion:
        #      (a)                 (b)
        #        \                 /
        # Transpose([0,2,3,1])   Transpose([0,2,3,1])
        #          \             /
        #          SkipGroupNorm
        #            /     \
        #           /       Transpose([0, 3, 1, 2])
        #          /           \
        #        (d)            (c)
        nodes = self.model.match_parent_path(node, ["Transpose", "Add"], [0, 0], output_name_to_node)
        if nodes is None:
            return

        (transpose, add) = nodes
        if transpose in self.nodes_to_remove or add in self.nodes_to_remove:
            return

        if self.has_multiple_consumers(transpose.output[0], input_name_to_nodes):
            return

        permutation = OnnxModel.get_node_attribute(transpose, "perm")
        if permutation != [0, 2, 3, 1]:
            return

        inputs = []
        bias = None
        for i in range(2):
            matched_transpose = self.match_transpose_from_nhwc(add.input[i], input_name_to_nodes, output_name_to_node)
            if matched_transpose:
                # When there is an Transpose node before Add (see examples in match_bias_path), we do not need to
                # insert another Transpose node. The existing Transpose node will be removed in prune_graph if it
                # has only one consumer.
                inputs.append(matched_transpose.input[0])
                # See whether it match bias pattern.
                if bias is None:
                    bias = self.match_bias_path(matched_transpose, input_name_to_nodes, output_name_to_node)
            else:
                # Otherwise, insert a Transpose node before Add.
                new_transpose = self.create_transpose_node(add.input[i], [0, 2, 3, 1])
                self.model.add_node(new_transpose, self.this_graph_name)
                inputs.append(new_transpose.output[0])

        # The Add inputs are in NCHW (before the matched Transpose to NHWC).
        skip, broadcast = self.get_skip_index(add, is_channel_last=False)
        if skip < 0:
            return

        # SkipGroupNorm input order: input, gamma, beta, skip[, bias].
        inputs = [inputs[1 - skip], node.input[1], node.input[2], inputs[skip]]
        if bias:
            inputs = [*inputs, bias]

        # Copy the output names so that appending the optional add output below
        # does not mutate the original GroupNorm node's proto in place.
        outputs = list(node.output)

        new_node_name = self.model.create_node_name(self.fused_op_type, name_prefix="SkipGroupNorm")
        if self.has_multiple_consumers(add.output[0], input_name_to_nodes):
            add_out_name = new_node_name + "_add_out"
            outputs.append(add_out_name)

            # Insert a Transpose node after add output.
            add_out_transpose = self.create_transpose_node(add_out_name, [0, 3, 1, 2], add.output[0])
            self.model.add_node(add_out_transpose, self.this_graph_name)

        skip_group_norm = helper.make_node(
            self.fused_op_type,
            inputs=inputs,
            outputs=outputs,
            name=new_node_name,
        )
        skip_group_norm.domain = "com.microsoft"

        self.increase_counter(
            f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})"
        )

        # Pass attributes from GroupNorm node to SkipGroupNorm
        for att in node.attribute:
            skip_group_norm.attribute.extend([att])

        self.nodes_to_remove.extend([add, transpose, node])
        self.nodes_to_add.append(skip_group_norm)
        self.node_name_to_graph_name[skip_group_norm.name] = self.this_graph_name
        self.prune_graph = True