onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,301 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+ from typing import Tuple, Union
7
+
8
+ import numpy as np
9
+ from fusion_base import Fusion
10
+ from onnx import NodeProto, TensorProto, helper, numpy_helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
class FusionAttentionVae(Fusion):
    """
    Fuse Attention subgraph of Vae Decoder into one Attention node.
    """

    def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int):
        """Set up the fusion.

        Args:
            model (OnnxModel): model to run the fusion on
            hidden_size (int): user specified hidden size, used as fallback when detection fails
            num_heads (int): user specified number of attention heads, used as fallback when detection fails
        """
        # The base class receives the fused op name ("Attention") and the op types
        # that trigger `fuse` (here only "Softmax").
        super().__init__(model, "Attention", ["Softmax"])
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        # Flags to show warning only once
        self.num_heads_warning = True
        self.hidden_size_warning = True

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, add_q: NodeProto) -> Tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node.

        num_heads is read from the constant at input index 2 of the node that produces the
        second input of ``reshape_q`` (that node must have exactly four inputs).
        hidden_size is the length of the one dimensional constant bias of ``add_q``.
        Whenever one of them cannot be detected, the user specified values are returned.

        Args:
            reshape_q (NodeProto): reshape node for Q
            add_q (NodeProto): add node for Q

        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """
        user_specified = (self.num_heads, self.hidden_size)

        concat = self.model.get_parent(reshape_q, 1)
        if concat is None or len(concat.input) != 4:
            return user_specified

        value = self.model.get_constant_value(concat.input[2])
        if not isinstance(value, np.ndarray) or value.size != 1:
            return user_specified

        # Use item() since int() on an array with ndim > 0 is deprecated in recent NumPy.
        num_heads = int(value.item())
        if num_heads <= 0:
            return user_specified

        _, bias = self.model.get_constant_input(add_q)
        if not isinstance(bias, np.ndarray) or bias.ndim != 1:
            return user_specified

        hidden_size = bias.shape[0]

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(
                    "Detected number of attention heads is %d. Ignore --num_heads %d", num_heads, self.num_heads
                )
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning("Detected hidden size is %d. Ignore --hidden_size %d", hidden_size, self.hidden_size)
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

    def create_attention_node(
        self,
        q_matmul: NodeProto,
        q_add: NodeProto,
        k_matmul: NodeProto,
        k_add: NodeProto,
        v_matmul: NodeProto,
        v_add: NodeProto,
        num_heads: int,
        hidden_size: int,
        input_name: str,
        output_name: str,
    ) -> Union[NodeProto, None]:
        """Create an Attention node.

        The Q, K and V weights are packed into one initializer and the Q, K and V biases
        into another one; both are registered before the node is returned.

        Args:
            q_matmul (NodeProto): MatMul node in fully connected layer for Q
            q_add (NodeProto): Add bias node in fully connected layer for Q
            k_matmul (NodeProto): MatMul node in fully connected layer for K
            k_add (NodeProto): Add bias node in fully connected layer for K
            v_matmul (NodeProto): MatMul node in fully connected layer for V
            v_add (NodeProto): Add bias node in fully connected layer for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input_name (str): input name
            output_name (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.

        Raises:
            ValueError: hidden_size is positive but differs from the input dimension of the weights.
        """
        if q_matmul.input[0] != input_name or k_matmul.input[0] != input_name or v_matmul.input[0] != input_name:
            logger.debug(
                "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
                q_matmul.input[0],
                k_matmul.input[0],
                v_matmul.input[0],
            )
            return None

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
            return None

        q_weight_tensor = self.model.get_initializer(q_matmul.input[1])
        k_weight_tensor = self.model.get_initializer(k_matmul.input[1])
        v_weight_tensor = self.model.get_initializer(v_matmul.input[1])
        if not (q_weight_tensor and k_weight_tensor and v_weight_tensor):
            return None

        # Sometimes weights are stored in fp16
        if q_weight_tensor.data_type == TensorProto.FLOAT16:
            logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
            return None

        # The bias can be either input of the Add node.
        q_bias_tensor = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
        k_bias_tensor = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
        v_bias_tensor = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
        if q_bias_tensor is None or k_bias_tensor is None or v_bias_tensor is None:
            logger.debug("fuse_attention: failed to find bias initializers of q, k and v")
            return None

        q_weight = numpy_helper.to_array(q_weight_tensor)
        k_weight = numpy_helper.to_array(k_weight_tensor)
        v_weight = numpy_helper.to_array(v_weight_tensor)

        # q, k and v weights are expected to have the same shape
        if q_weight.shape != k_weight.shape or q_weight.shape != v_weight.shape:
            return None

        qw_in_size = q_weight.shape[0]

        if hidden_size > 0 and hidden_size != qw_in_size:
            raise ValueError(
                f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
                "Please provide a correct input hidden size or pass in 0"
            )

        q_bias = numpy_helper.to_array(q_bias_tensor)
        k_bias = numpy_helper.to_array(k_bias_tensor)
        v_bias = numpy_helper.to_array(v_bias_tensor)

        if not (q_bias.size == k_bias.size == v_bias.size):
            logger.debug(
                "fuse_attention: bias sizes of q, k and v differ: %d, %d, %d", q_bias.size, k_bias.size, v_bias.size
            )
            return None

        # For 2d weights, the shapes would be [in_size, out_size].
        # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
        qw_out_size = np.prod(q_weight.shape[1:])

        qkv_weight = np.stack((q_weight, k_weight, v_weight), axis=1)
        qkv_weight_dim = 3 * int(qw_out_size)

        # Pack the real biases. They must not be replaced by zeros: the original Add
        # nodes are removed from the graph once the fused node takes over.
        qkv_bias = np.stack((q_bias.reshape(-1), k_bias.reshape(-1), v_bias.reshape(-1)), axis=0)
        qkv_bias_dim = 3 * int(q_bias.size)

        attention_node_name = self.model.create_node_name("Attention")

        self.add_initializer(
            name=attention_node_name + "_qkv_weight",
            data_type=TensorProto.FLOAT,
            dims=[qw_in_size, qkv_weight_dim],
            vals=qkv_weight,
        )

        self.add_initializer(
            name=attention_node_name + "_qkv_bias",
            data_type=TensorProto.FLOAT,
            dims=[qkv_bias_dim],
            vals=qkv_bias,
        )

        attention_inputs = [
            input_name,
            attention_node_name + "_qkv_weight",
            attention_node_name + "_qkv_bias",
        ]

        attention_node = helper.make_node(
            "Attention",
            inputs=attention_inputs,
            outputs=[output_name],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        self.increase_counter("Attention (self attention)")
        return attention_node

    def fuse(self, softmax_node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the attention subgraph around ``softmax_node`` into one Attention node.

        The subgraph is matched downstream of the Softmax
        (MatMul -> Reshape -> Transpose -> Reshape -> MatMul -> Add -> Transpose) and
        upstream along the V, QK, Q and K paths. The method returns early, without
        scheduling any graph change, as soon as one part does not match.

        Args:
            softmax_node: Softmax node to start the matching from
            input_name_to_nodes: mapping used to look up the children of a node
            output_name_to_node: not used by this fusion
        """
        # Downstream of Softmax: each node is only required to exist.
        matmul_qkv = self.model.find_first_child_by_type(softmax_node, "MatMul", input_name_to_nodes, recursive=False)
        if matmul_qkv is None:
            return

        reshape_qkv = self.model.find_first_child_by_type(matmul_qkv, "Reshape", input_name_to_nodes, recursive=False)
        if reshape_qkv is None:
            return

        transpose_qkv = self.model.find_first_child_by_type(
            reshape_qkv, "Transpose", input_name_to_nodes, recursive=False
        )
        if transpose_qkv is None:
            return

        reshape_out = self.model.find_first_child_by_type(
            transpose_qkv, "Reshape", input_name_to_nodes, recursive=False
        )
        if reshape_out is None:
            return

        matmul_out = self.model.find_first_child_by_type(reshape_out, "MatMul", input_name_to_nodes, recursive=False)
        if matmul_out is None:
            return

        add_out = self.model.find_first_child_by_type(matmul_out, "Add", input_name_to_nodes, recursive=False)
        if add_out is None:
            return

        transpose_out = self.model.find_first_child_by_type(add_out, "Transpose", input_name_to_nodes, recursive=False)
        if transpose_out is None:
            return

        # V path: second input of the MatMul that consumes the Softmax output.
        v_nodes = self.model.match_parent_path(
            matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None]
        )
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (_, _, _, add_v, matmul_v) = v_nodes

        # QK path: first input of the same MatMul.
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
        if qk_nodes is None:
            logger.debug("fuse_attention: failed to match qk path")
            return
        (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes

        # Q path: first input of the Q x K MatMul.
        q_nodes = self.model.match_parent_path(
            matmul_qk, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, None]
        )
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (_, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes

        # K path: second input of the Q x K MatMul (has one more Transpose than the Q path).
        k_nodes = self.model.match_parent_path(
            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None]
        )
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
        (_, _, _, _, add_k, matmul_k) = k_nodes

        # The fused node writes to the output of this node.
        attention_last_node = reshape_out

        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, add_q)
        if q_num_heads <= 0:
            logger.debug("fuse_attention: failed to detect num_heads")
            return

        # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
        new_node = self.create_attention_node(
            matmul_q,
            add_q,
            matmul_k,
            add_k,
            matmul_v,
            add_v,
            q_num_heads,
            q_hidden_size,
            matmul_q.input[0],
            attention_last_node.output[0],
        )
        if new_node is None:
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])

        # Use prune graph to remove nodes since they are shared by all attention nodes.
        self.prune_graph = True