onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,227 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import logging
6
+ from typing import Union
7
+
8
+ from fusion_attention import AttentionMask, FusionAttention
9
+ from fusion_utils import NumpyHelper
10
+ from onnx import NodeProto, helper
11
+ from onnx_model import OnnxModel
12
+ from onnx_model_bert import BertOnnxModel
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class FusionTnlrAttention(FusionAttention):
    """
    Fuse TNLR Attention subgraph into one Attention node.
    TNLR Attention has extra addition after qk nodes and adopts [S, B, NH] as I/O shape.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def create_attention_node(
        self,
        mask_index: str,
        matmul: NodeProto,
        add: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
        add_qk_str: str,
    ) -> Union[NodeProto, None]:
        """Build a fused com.microsoft::Attention node from the matched Q/K/V projection.

        Args:
            mask_index: name of the mask tensor, or None when no mask input is used.
            matmul: MatMul node holding the (already merged) QKV projection weight initializer.
            add: Add node holding the QKV projection bias initializer.
            num_heads: number of attention heads; must be positive.
            hidden_size: model hidden dimension; if positive, must be divisible by num_heads.
            input: name of the attention root input tensor.
                   (NOTE(review): parameter shadows the `input` builtin; kept for signature compatibility.)
            output: name of the fused node's output tensor.
            add_qk_str: name of the tensor added to Q*K before Softmax, or None.

        Returns:
            The new Attention NodeProto, or None when the shapes are inconsistent or
            the weight/bias initializers cannot be found.
        """
        assert num_heads > 0
        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        weight = self.model.get_initializer(matmul.input[1])
        # The bias initializer may sit on either input of the Add node, so try both.
        bias = self.model.get_initializer(add.input[1]) or self.model.get_initializer(add.input[0])

        if weight is None or bias is None:
            return None

        qkv_weight = NumpyHelper.to_array(weight)
        qkv_bias = NumpyHelper.to_array(bias)

        attention_node_name = self.model.create_node_name("Attention")

        # Re-pack weight/bias as raw tensors in the same dtype as the source initializer.
        # Expected layouts: weight [hidden, 3*hidden], bias [3*hidden] — Q, K, V concatenated.
        tensor_dtype = weight.data_type
        np_type = helper.tensor_dtype_to_np_dtype(tensor_dtype)
        weight = helper.make_tensor(
            name=attention_node_name + "_qkv_weight",
            data_type=tensor_dtype,
            dims=[hidden_size, 3 * hidden_size],
            vals=qkv_weight.astype(np_type).tobytes(),
            raw=True,
        )
        self.model.add_initializer(weight, self.this_graph_name)

        bias = helper.make_tensor(
            name=attention_node_name + "_qkv_bias",
            data_type=tensor_dtype,
            dims=[3 * hidden_size],
            vals=qkv_bias.astype(np_type).tobytes(),
            raw=True,
        )
        self.model.add_initializer(bias, self.this_graph_name)

        attention_inputs = [
            input,
            attention_node_name + "_qkv_weight",
            attention_node_name + "_qkv_bias",
        ]
        # Mask is the 4th (optional) input; use an empty name when absent.
        if mask_index is not None:
            attention_inputs.append(mask_index)
        else:
            attention_inputs.append("")

        # add_qk is the 6th input, so an empty placeholder fills the 5th (past) slot.
        if add_qk_str is not None:
            attention_inputs.append("")
            attention_inputs.append(add_qk_str)

        attention_node = helper.make_node(
            "Attention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        # Attention is an ONNX Runtime contrib op, not a standard ONNX op.
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Match the TNLR attention subgraph anchored at a SkipLayerNormalization node
        and replace it with one fused Attention node plus a Transpose.

        Args:
            normalize_node: candidate anchor node (only SkipLayerNormalization is handled).
            input_name_to_nodes: map of tensor name -> consumer nodes (unused here).
            output_name_to_node: map of tensor name -> producer node.
        """
        # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm
        # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern
        start_node = normalize_node
        if normalize_node.op_type != "SkipLayerNormalization":
            return

        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Where", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (_, _, matmul_below, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
        else:
            return

        # The "other" input of the SkipLayerNormalization (the one not produced by the
        # qkv path) is the residual/root input feeding the attention projections.
        other_inputs = []
        for _i, input in enumerate(start_node.input):
            if input not in output_name_to_node:
                continue

            if input == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(input)
        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]

        # V path: MatMul -> Add -> Slice -> Reshape -> Transpose feeding matmul_qkv input 1.
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if v_nodes is None:
            return
        (_, _, _, add, matmul) = v_nodes

        # Transpose above the shared projection MatMul; its input is used later to
        # rewire the fused node (TNLR uses [S, B, NH] I/O, see class docstring).
        upper_nodes = self.model.match_parent_path(matmul, ["Transpose"], [0])
        transpose = upper_nodes[0]

        # QK path: the extra Add between MatMul(Q,K) and Softmax is TNLR-specific.
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            return
        (_, add_qk, matmul_qk) = qk_nodes

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 1],
        )
        if q_nodes is None:
            return
        # NOTE: `add`/`matmul` are deliberately re-bound for each path (v, then q,
        # then k); the k-path bindings are the ones passed to create_attention_node.
        add = q_nodes[-2]
        matmul = q_nodes[-1]

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if k_nodes is None:
            return
        add = k_nodes[-2]
        matmul = k_nodes[-1]

        # The tensor added after Q*K (relative position bias) comes through Reshape <- Where.
        relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0])
        if relative_position_bias_nodes is None:
            return

        if matmul.input[0] == root_input:
            mask_index = None
            attention_last_node = reshape_qkv
            # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
            # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately
            new_node = self.create_attention_node(
                mask_index,
                matmul,
                add,
                self.num_heads,
                self.hidden_size,
                root_input,
                attention_last_node.output[0],
                relative_position_bias_nodes[0].input[0],
            )
            if new_node is None:
                return

            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

            # Add a transpose node after the attention node
            # perm=[1, 0, 2] swaps the first two axes — presumably converting between
            # the [S, B, NH] layout and the Attention op's layout; TODO confirm.
            back_transpose = helper.make_node(
                "Transpose",
                ["back_transpose_in_" + new_node.name],
                [new_node.output[0]],
                "back_transpose_" + new_node.name,
                perm=[1, 0, 2],
            )
            self.model.add_node(back_transpose, self.this_graph_name)
            # Rewire: attention reads the pre-transpose tensor and feeds the new
            # back-transpose, which takes over the original output name.
            new_node.input[0] = transpose.input[0]
            new_node.output[0] = "back_transpose_in_" + new_node.name

            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True
218
+
219
+
220
class TnlrOnnxModel(BertOnnxModel):
    """BERT-style ONNX model wrapper for TNLR that wires in the TNLR-specific attention fusion."""

    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        # Build the shared mask helper first, then the fusion pass that consumes it.
        mask_helper = AttentionMask(self)
        self.attention_mask = mask_helper
        self.attention_fusion = FusionTnlrAttention(self, self.hidden_size, self.num_heads, mask_helper)

    def fuse_attention(self):
        """Run the TNLR attention fusion pass over the whole graph."""
        self.attention_fusion.apply()
@@ -0,0 +1,259 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import logging
7
+ from typing import Optional
8
+
9
+ from fusion_attention_unet import FusionAttentionUnet
10
+ from fusion_bias_add import FusionBiasAdd
11
+ from fusion_biassplitgelu import FusionBiasSplitGelu
12
+ from fusion_group_norm import FusionGroupNorm
13
+ from fusion_nhwc_conv import FusionNhwcConv
14
+ from fusion_options import FusionOptions
15
+ from fusion_skip_group_norm import FusionSkipGroupNorm
16
+ from fusion_transpose import FusionInsertTranspose, FusionTranspose
17
+ from import_utils import is_installed
18
+ from onnx import ModelProto
19
+ from onnx_model import OnnxModel
20
+ from onnx_model_bert import BertOnnxModel
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class UnetOnnxModel(BertOnnxModel):
    """Graph optimizer for Stable Diffusion UNet models.

    Extends the generic BERT fusions with UNet-specific passes (GroupNorm,
    BiasSplitGelu, NHWC Conv, packed QKV/KV multi-head attention) and runs
    them in a fixed order in `_optimize`. The pass order is significant:
    later fusions match patterns produced by earlier ones.
    """

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize UNet ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
        """
        # Either both are auto-detected (0), or hidden_size must be divisible by num_heads.
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def preprocess(self):
        """Graph cleanup applied before the fusion passes."""
        self.remove_useless_div()

    def postprocess(self):
        """Graph cleanup applied after all fusion passes."""
        self.prune_graph()
        self.remove_unused_constant()

    def remove_useless_div(self):
        """Remove Div by 1"""
        div_nodes = [node for node in self.nodes() if node.op_type == "Div"]

        nodes_to_remove = []
        for div in div_nodes:
            # find_constant_input result 1 means input[1] (the divisor) is the
            # constant 1.0, making the Div a no-op.
            if self.find_constant_input(div, 1.0) == 1:
                nodes_to_remove.append(div)

        for node in nodes_to_remove:
            # Reroute all consumers of the Div output directly to its input.
            self.replace_input_of_all_nodes(node.output[0], node.input[0])

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            logger.info("Removed %d Div nodes", len(nodes_to_remove))

    def convert_conv_to_nhwc(self):
        """Rewrite Conv nodes to the NhwcConv contrib op."""
        # Transpose weights in offline might help since ORT does not apply constant-folding on Transpose nodes.
        conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=True)
        conv_to_nhwc_conv.apply()

    def merge_adjacent_transpose(self):
        """Merge adjacent Transpose nodes, then drop identity Transposes left behind."""
        fusion_transpose = FusionTranspose(self)
        fusion_transpose.apply()

        remove_count = 0
        nodes = self.get_nodes_by_op_type("Transpose")
        for node in nodes:
            permutation = OnnxModel.get_node_attribute(node, "perm")
            assert isinstance(permutation, list)
            # Only identity permutations (perm == [0, 1, ..., n-1]) can be removed.
            if permutation != list(range(len(permutation))):
                continue
            # Removing a node attached to a graph input/output would change the
            # model interface; the merge fusion is not expected to produce that.
            assert not (
                self.find_graph_output(node.output[0])
                or self.find_graph_input(node.input[0])
                or self.find_graph_output(node.input[0])
            )

            # Let all children nodes skip current Transpose node and link to its parent
            # Note that we cannot update parent node output since parent node might have more than one children.
            self.replace_input_of_all_nodes(node.output[0], node.input[0])

            self.remove_node(node)
            remove_count += 1

        total = len(fusion_transpose.nodes_to_remove) + remove_count
        if total:
            logger.info("Removed %d Transpose nodes", total)

    def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None):
        """Fuse self- and cross-attention subgraphs into MultiHeadAttention nodes.

        Args:
            options (FusionOptions, optional): packed QKV/KV toggles; None enables both.
        """
        # Self Attention
        enable_packed_qkv = (options is None) or options.enable_packed_qkv
        self_attention_fusion = FusionAttentionUnet(
            self,
            self.hidden_size,
            self.num_heads,
            is_cross_attention=False,
            enable_packed_qkv=enable_packed_qkv,
            enable_packed_kv=False,
        )
        self_attention_fusion.apply()

        # Cross Attention
        enable_packed_kv = (options is None) or options.enable_packed_kv
        cross_attention_fusion = FusionAttentionUnet(
            self,
            self.hidden_size,
            self.num_heads,
            is_cross_attention=True,
            enable_packed_qkv=False,
            enable_packed_kv=enable_packed_kv,
        )
        cross_attention_fusion.apply()

    def fuse_bias_add(self):
        """Fuse Add-of-bias patterns into BiasAdd nodes."""
        fusion = FusionBiasAdd(self)
        fusion.apply()

    def optimize(self, options: Optional[FusionOptions] = None):
        """Run the full fusion pipeline, with a tqdm progress bar when available.

        Args:
            options (FusionOptions, optional): toggles for individual passes;
                None enables the default set.
        """
        if is_installed("tqdm"):
            import tqdm
            from tqdm.contrib.logging import logging_redirect_tqdm

            with logging_redirect_tqdm():
                # NOTE: `steps` must match the number of progress_bar.update(1)
                # calls in _optimize (currently 18).
                steps = 18
                progress_bar = tqdm.tqdm(range(steps), initial=0, desc="fusion")
                self._optimize(options, progress_bar)
        else:
            logger.info("tqdm is not installed. Run optimization without progress bar")
            self._optimize(options, None)

    def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None):
        """Apply all fusion/cleanup passes in a fixed, order-sensitive sequence.

        Args:
            options (FusionOptions, optional): pass toggles; None enables defaults.
            progress_bar: optional tqdm progress bar; updated once per pass.
        """
        if (options is not None) and not options.enable_shape_inference:
            self.disable_shape_inference()

        self.utils.remove_identity_nodes()
        if progress_bar:
            progress_bar.update(1)

        # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()
        if progress_bar:
            progress_bar.update(1)

        self.preprocess()
        if progress_bar:
            progress_bar.update(1)

        self.fuse_reshape()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_group_norm:
            channels_last = (options is None) or options.group_norm_channels_last
            group_norm_fusion = FusionGroupNorm(self, channels_last)
            group_norm_fusion.apply()

            # GroupNorm fusion may leave layouts needing extra Transpose nodes.
            insert_transpose_fusion = FusionInsertTranspose(self)
            insert_transpose_fusion.apply()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_bias_splitgelu:
            bias_split_gelu_fusion = FusionBiasSplitGelu(self)
            bias_split_gelu_fusion.apply()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_attention:
            # self.save_model_to_file("before_mha.onnx")
            self.fuse_multi_head_attention(options)
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()
        if progress_bar:
            progress_bar.update(1)

        self.fuse_shape()
        if progress_bar:
            progress_bar.update(1)

        # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_skip_group_norm:
            skip_group_norm_fusion = FusionSkipGroupNorm(self)
            skip_group_norm_fusion.apply()
        if progress_bar:
            progress_bar.update(1)

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()
        if progress_bar:
            progress_bar.update(1)

        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()
        if progress_bar:
            progress_bar.update(1)

        if options is None or options.enable_nhwc_conv:
            self.convert_conv_to_nhwc()
            self.merge_adjacent_transpose()
        if progress_bar:
            progress_bar.update(1)

        if options is not None and options.enable_bias_add:
            self.fuse_bias_add()
        if progress_bar:
            progress_bar.update(1)

        self.postprocess()
        if progress_bar:
            progress_bar.update(1)

        logger.info(f"opset version: {self.get_opset_version()}")

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        op_count = {}
        ops = [
            "Attention",
            "MultiHeadAttention",
            "LayerNormalization",
            "SkipLayerNormalization",
            "BiasSplitGelu",
            "GroupNorm",
            "SkipGroupNorm",
            "NhwcConv",
            "BiasAdd",
        ]

        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)

        logger.info(f"Optimized operators:{op_count}")
        return op_count
@@ -0,0 +1,43 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+ from typing import Optional
8
+
9
+ from fusion_attention_vae import FusionAttentionVae
10
+ from fusion_options import FusionOptions
11
+ from onnx import ModelProto
12
+ from onnx_model_unet import UnetOnnxModel
13
+
14
+ logger = getLogger(__name__)
15
+
16
+
17
class VaeOnnxModel(UnetOnnxModel):
    """UNet-based graph optimizer specialized for VAE models (self-attention only)."""

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        # Either both are auto-detected (0), or hidden_size must be divisible by num_heads.
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None):
        """Fuse self-attention subgraphs; a VAE graph has no cross attention."""
        # Self Attention
        FusionAttentionVae(self, self.hidden_size, self.num_heads).apply()

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        ops = (
            "Attention",
            "GroupNorm",
            "SkipGroupNorm",
            "NhwcConv",
        )
        op_count = {op: len(self.get_nodes_by_op_type(op)) for op in ops}

        logger.info(f"Optimized operators:{op_count}")
        return op_count
@@ -0,0 +1,55 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from fusion_utils import NumpyHelper
6
+ from onnx import ModelProto, TensorProto
7
+ from onnx.external_data_helper import set_external_data
8
+ from onnx_model import OnnxModel
9
+
10
+ from onnxruntime import OrtValue
11
+
12
+
13
def extract_raw_data_from_model(model: ModelProto):
    """
    Extract raw-data initializers from the model as OrtValues and strip them
    from the proto (each initializer is re-marked as external data).
    Note this function does not handle external data that is not loaded into the model as raw data.

    Args:
        model (ModelProto): the model proto to extract external data from.
    Returns:
        (external_names, external_values): a 2-tuple of parallel sequences of
        external data names and OrtValues. Both are empty when the model has
        no raw-data initializers.
    """
    external_data = []
    onnx_model = OnnxModel(model)
    for graph in onnx_model.graphs():
        for initializer in graph.initializer:
            name = initializer.name

            if initializer.HasField("raw_data"):
                numpy_tensor = NumpyHelper.to_array(initializer)
                ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
                external_data.append((name, ort_value))
                # mimic set_external_data; restore the name since the helper may touch it,
                # then drop the raw bytes so the proto no longer carries the weights.
                set_external_data(initializer, location="foo.bin")
                initializer.name = name
                initializer.ClearField("raw_data")

    # Bug fix: `zip(*[])` unpacks to ZERO items, so the documented two-value
    # unpacking (`names, values = ...`) raised ValueError for models without
    # raw-data initializers. Always return a concrete 2-tuple instead.
    if not external_data:
        return (), ()
    return tuple(zip(*external_data))
39
+
40
+
41
def has_external_data(model: ModelProto):
    """
    Check if the model has external data.

    Args:
        model (ModelProto): the model proto to check for external data.
    Returns:
        bool: True if the model has external data, False otherwise.
    """
    # Scan initializers of the main graph and all subgraphs; short-circuits
    # on the first externally-located tensor, like the explicit loop it replaces.
    return any(
        tensor.HasField("data_location") and tensor.data_location == TensorProto.EXTERNAL
        for graph in OnnxModel(model).graphs()
        for tensor in graph.initializer
    )