onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,217 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+ from typing import Dict
8
+
9
+ from fusion_base import Fusion
10
+ from fusion_utils import FusionUtils
11
+ from onnx import helper
12
+ from onnx_model import OnnxModel
13
+
14
+ logger = getLogger(__name__)
15
+
16
+
17
class FusionQOrderedMatMul(Fusion):
    """Fuse a quantized MatMul + Bias Add (+ optional Residual Add) subgraph into a single
    com.microsoft QOrderedMatMul node.

    The expected pattern is a MatMul whose two inputs each flow through DequantizeLinear
    nodes and whose single consumer is an Add (the bias). The bias Add may feed either a
    QuantizeLinear directly, or another Add (a residual connection) that then feeds the
    QuantizeLinear. On a successful match the whole subgraph is replaced by one
    QOrderedMatMul node that consumes the raw quantized tensors and their scales.
    """

    def __init__(self, model: OnnxModel):
        # Search for candidate "MatMul" nodes; fused nodes are named "QOrderedMatMul".
        super().__init__(model, "QOrderedMatMul", "MatMul")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """Attempt to fuse the subgraph rooted at `node` (a MatMul).

        Any failed structural check returns early and leaves the graph untouched.
        On success, the matched nodes are queued in self.nodes_to_remove and the
        fused QOrderedMatMul node in self.nodes_to_add; the base Fusion class is
        presumably responsible for applying those queues afterwards.
        """
        matmul_children = self.model.get_children(node, input_name_to_nodes)

        # Should only have 1 child - Bias Add
        if len(matmul_children) != 1 or matmul_children[0].op_type != "Add":
            return

        bias_add_node = matmul_children[0]

        # At least one of the inputs to the Bias Add node must be a constant
        bias_add_node_index = 0
        if (
            self.model.get_constant_value(bias_add_node.input[0]) is None
            and self.model.get_constant_value(bias_add_node.input[1]) is None
        ):
            return

        # Index of the constant (bias) input on the Add node: 0 unless input[0] is dynamic.
        if self.model.get_constant_value(bias_add_node.input[0]) is None:
            bias_add_node_index = 1

        bias_add_children = self.model.get_children(bias_add_node, input_name_to_nodes)

        if len(bias_add_children) != 1:
            return

        bias_add_child = bias_add_children[0]

        # Bias Add can have another Add downstream (Residual Add layer)
        residual_add_node = None

        downstream_quantize_node = None

        if bias_add_child.op_type == "Add":
            # Residual variant: BiasAdd -> ResidualAdd -> QuantizeLinear
            residual_add_node = bias_add_child

            residual_add_children = self.model.get_children(residual_add_node, input_name_to_nodes)

            if len(residual_add_children) != 1 or residual_add_children[0].op_type != "QuantizeLinear":
                return

            downstream_quantize_node = residual_add_children[0]

        elif bias_add_child.op_type == "QuantizeLinear":
            # Simple variant: BiasAdd -> QuantizeLinear
            downstream_quantize_node = bias_add_child

        else:
            return

        # Make sure the downstream QuantizeLinear has the proper zero points and scales
        if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
            return

        # The first input to MatMul should flow through a DequantizeLinear node
        first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
            node,
            [(["DequantizeLinear"], [0])],
            output_name_to_node,
        )

        # If Attention is not fused, this is the pattern to look for
        # leading up to the MatMul
        reshape_node_0 = None
        transpose_node_0 = None
        if first_path_id < 0:
            # Fallback pattern: Reshape <- Transpose <- DequantizeLinear <- QuantizeLinear
            first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
                node,
                [(["Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear"], [0, 0, 0, 0])],
                output_name_to_node,
            )

            if first_path_id < 0:
                return

            reshape_node_0 = first_input_parent_nodes[0]
            transpose_node_0 = first_input_parent_nodes[1]
            dequantize_node_0 = first_input_parent_nodes[2]
        else:
            dequantize_node_0 = first_input_parent_nodes[0]

        # Make sure the upstream DequantizeLinear-0 has the proper zero points and scales
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_node_0, self.model):
            return

        # The second input to MatMul should flow through a DequantizeLinear node
        dequantize_node_1 = None
        is_weight_transpose_required = True

        # Preferred weight pattern already contains a Transpose, so no offline transpose needed.
        weight_path_id, weight_nodes, _ = self.model.match_parent_paths(
            node,
            [(["DequantizeLinear", "QuantizeLinear", "Transpose", "DequantizeLinear"], [1, 0, 0, 0])],
            output_name_to_node,
        )

        if weight_path_id < 0:
            # Fallback: a bare DequantizeLinear feeding the weight input; the weight
            # initializer will need an offline transpose (see below).
            weight_path_id, weight_nodes, _ = self.model.match_parent_paths(
                node,
                [(["DequantizeLinear"], [1])],
                output_name_to_node,
            )

            if weight_path_id < 0:
                return

            dequantize_node_1 = weight_nodes[0]
        else:
            is_weight_transpose_required = False
            # weight_nodes is ordered from the MatMul outwards; index 3 is the
            # innermost DequantizeLinear holding the quantized weight initializer.
            dequantize_node_1 = weight_nodes[3]

        # Check if weight 'B' is a constant
        if self.model.get_constant_value(dequantize_node_1.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_node_1, self.model, False):
            return

        # Make sure the upstream flow into the Residual Add node flows through a DQ node
        residual_add_dequantize_node = None

        if residual_add_node is not None:
            residual_path_id, residual_input_parent_nodes, _ = self.model.match_parent_paths(
                residual_add_node,
                [
                    (["DequantizeLinear"], [1]),
                ],
                output_name_to_node,
            )

            if residual_path_id < 0:
                return

            residual_add_dequantize_node = residual_input_parent_nodes[0]

        # Make sure the upstream DequantizeLinear to the Residual Add has the proper zero points and scales
        if residual_add_dequantize_node is not None and not FusionUtils.check_qdq_node_for_fusion(
            residual_add_dequantize_node, self.model
        ):
            return

        # Subgraph nodes to be fused
        subgraph_nodes = [node, bias_add_node]  # MatMul + Bias Add

        if residual_add_node is not None:
            subgraph_nodes.extend([residual_add_node])  # Residual Add

        subgraph_nodes.extend(weight_nodes)
        subgraph_nodes.extend([downstream_quantize_node])  # Downstream Q node

        # Abort unless no other node consumes an intermediate output of the subgraph.
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, downstream_quantize_node.output, input_name_to_nodes, output_name_to_node
        ):
            logger.debug("It is not safe to fuse QOrderedMatMul node. Skip")
            return

        # Deal with the case where-in the Attention subgraph is not fused:
        # re-wire the Transpose to read directly from the quantized tensor,
        # bypassing the DequantizeLinear that is being fused away.
        if transpose_node_0 is not None:
            self.model.replace_node_input(transpose_node_0, transpose_node_0.input[0], dequantize_node_0.input[0])

        # Make inputs: quantized A + scale, quantized B + scale, output scale, bias.
        fused_node_inputs = [
            reshape_node_0.output[0] if reshape_node_0 is not None else dequantize_node_0.input[0],
            dequantize_node_0.input[1],
            dequantize_node_1.input[0],
            dequantize_node_1.input[1],
            downstream_quantize_node.input[1],
            bias_add_node.input[bias_add_node_index],
        ]

        # Optional residual input: quantized tensor + its scale.
        if residual_add_node is not None:
            fused_node_inputs.append(residual_add_dequantize_node.input[0])
            fused_node_inputs.append(residual_add_dequantize_node.input[1])

        # The MatMul weight 'B' and 'bias' need some post-processing
        # Transpose weight 'B' from order ROW to order COL
        # This offline transpose is needed only while using the CUDA EP
        # TODO: Make this fusion logic EP-agnostic ?
        # NOTE(review): this mutates the weight initializer in place before the
        # fused node is created.
        if is_weight_transpose_required:
            weight_tensor = self.model.get_initializer(dequantize_node_1.input[0])
            FusionUtils.transpose_2d_int8_tensor(weight_tensor)

        fused_node = helper.make_node(
            "QOrderedMatMul",
            inputs=fused_node_inputs,
            outputs=[downstream_quantize_node.output[0]],
            name=self.model.create_node_name("QOrderedMatMul", name_prefix="QOrderedMatMul"),
        )

        # Memory ordering attributes: A and Y are order 1, B is order 0
        # (presumably ROW vs COL order for the cuBLASLt-backed contrib op — confirm
        # against the QOrderedMatMul contrib-op schema).
        fused_node.attribute.extend([helper.make_attribute("order_A", 1)])
        fused_node.attribute.extend([helper.make_attribute("order_B", 0)])
        fused_node.attribute.extend([helper.make_attribute("order_Y", 1)])

        fused_node.domain = "com.microsoft"

        self.nodes_to_remove.extend(subgraph_nodes)
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
@@ -0,0 +1,74 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import logging
7
+
8
+ from fusion_base import Fusion
9
+ from onnx import helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class FusionQuickGelu(Fusion):
    """Fuse the pattern ``x * Sigmoid(alpha * x)`` (alpha ~= 1.702) into a single
    com.microsoft QuickGelu node."""

    def __init__(self, model: OnnxModel):
        # Candidate nodes are the outer "Mul" of the pattern.
        super().__init__(model, "QuickGelu", ["Mul"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Attempt to fuse the subgraph rooted at `node` (the second/outer Mul).

        Fuse the following subgraph to `QuickGelu`:

            root_input
             /      \
            |       Mul ----+
            | (B = ~1.702)  |
            \       |       |
             \   Sigmoid    |---- `QuickGelu`
              \    /        |
               \  /         |
               Mul ---------+
                |
            root_output

        Any failed structural check logs a debug message and returns, leaving the
        graph untouched.
        """
        if node.op_type != "Mul":
            logger.debug("fuse_quickgelu: failed to match second Mul node")
            return

        second_mul_node = node
        root_input = second_mul_node.input[0]

        # Input 1 of the outer Mul must come from a Sigmoid.
        sigmoid_node = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1])
        if sigmoid_node is None:
            logger.debug("fuse_quickgelu: failed to match Sigmoid node")
            return
        sigmoid_node = sigmoid_node[0]

        # The Sigmoid must be fed by the inner (scaling) Mul.
        first_mul_node = self.model.match_parent_path(sigmoid_node, ["Mul"], [0])
        if first_mul_node is None:
            logger.debug("fuse_quickgelu: failed to match first Mul node")
            return
        first_mul_node = first_mul_node[0]

        # Fix: the multiplier may not be a constant (e.g. it is dynamic, or the
        # scalar sits on input[0]); the original code called .item() on None here
        # and raised AttributeError instead of skipping the fusion.
        approximation_constant = self.model.get_constant_value(first_mul_node.input[1])
        if approximation_constant is None:
            logger.debug("fuse_quickgelu: approximation multiplier is not a constant")
            return

        # The scaling constant must be close to 1.702 (fp16 value 1.7021484375).
        approximation_value = approximation_constant.item()
        if abs(approximation_value - 1.7021484375) >= 1e-3:
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return

        # Both Muls must share the same root input for the pattern to hold.
        if first_mul_node.input[0] != root_input:
            logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input")
            return

        new_node = helper.make_node(
            "QuickGelu",
            inputs=[root_input],
            outputs=[second_mul_node.output[0]],
            name=self.model.create_node_name("QuickGelu"),
        )
        new_node.domain = "com.microsoft"
        # Preserve the model's exact constant as the alpha attribute.
        new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)])

        self.nodes_to_remove.extend([first_mul_node, sigmoid_node, second_mul_node])
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
        self.increase_counter("QuickGelu")
@@ -0,0 +1,173 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+
8
+ import numpy as np
9
+ from fusion_base import Fusion
10
+ from onnx import TensorProto, helper
11
+ from onnx_model import OnnxModel
12
+
13
# Module-level logger; named after the module so fusion debug output can be
# filtered per-pass.
logger = getLogger(__name__)
14
+
15
+
16
class FusionReshape(Fusion):
    """Fold a Reshape whose shape is computed at runtime into a constant shape.

    Matches Reshape nodes whose shape input is produced by a Concat of
    Shape/Gather (or Shape/Slice/Squeeze) subgraphs over the *same* tensor
    being reshaped, and replaces the shape input with a constant int64 tensor
    using Reshape's special values (0 = copy dimension, -1 = infer dimension).
    The dead shape-computation nodes are left for a later pruning step.
    """

    def __init__(self, model: OnnxModel):
        # Anchor node type and fused node type are both "Reshape": the node is
        # rewired in place rather than replaced by a different op.
        super().__init__(model, "Reshape", "Reshape")
        # Set to True once a fusion happened, signaling that the now-unused
        # shape-computation subgraph should be pruned from the graph.
        self.prune_graph: bool = False

    def replace_reshape_node(self, shape, reshape_node, concat_node):
        """Point reshape_node's shape input at a new int64 Constant holding
        `shape`, and schedule the old Concat producer for removal.

        Args:
            shape: List of int dimension values (may contain 0 and -1).
            reshape_node: The Reshape node being fused (modified in place).
            concat_node: The Concat node that produced the dynamic shape.
        """
        shape_value = np.asarray(shape, dtype=np.int64)
        constant_shape_name = self.model.create_node_name("Constant", "constant_shape")
        new_node = helper.make_node(
            "Constant",
            inputs=[],
            outputs=[constant_shape_name],
            value=helper.make_tensor(
                name="const_tensor",
                data_type=TensorProto.INT64,
                dims=shape_value.shape,
                # Raw little-endian int64 bytes of the shape array.
                vals=bytes(shape_value),
                raw=True,
            ),
        )
        # Rewire the Reshape to read the constant instead of the Concat output.
        reshape_node.input[1] = constant_shape_name
        reshape_node.name = self.model.create_node_name("Reshape", "Reshape_Fuse")
        self.nodes_to_remove.extend([concat_node])
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

    def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
        """Try to fuse one Reshape node; no-op (early return) on any mismatch.

        The shape input must come from a Concat with 3 or 4 inputs whose first
        two entries are Shape->Gather(index 0) and Shape->Gather(index 1) of
        the reshaped tensor (each becomes 0 = "copy dim"), and whose remaining
        entries are either constants or recognized dynamic patterns (each
        dynamic one becomes -1 = "infer dim").
        """
        # Shape input must be produced inside the graph (not an initializer).
        if reshape_node.input[1] not in output_name_to_node:
            return

        concat_node = output_name_to_node[reshape_node.input[1]]
        # Only 3-D or 4-D target shapes are handled.
        if concat_node.op_type != "Concat" or len(concat_node.input) < 3 or len(concat_node.input) > 4:
            return

        # Dimension 0 must be Shape -> Gather -> Unsqueeze of some tensor.
        path0 = self.model.match_parent_path(
            concat_node,
            ["Unsqueeze", "Gather", "Shape"],
            [0, 0, 0],
            output_name_to_node,
        )
        if path0 is None:
            return

        (unsqueeze_0, gather_0, shape_0) = path0

        # Dimension 1 must match the same pattern.
        path1 = self.model.match_parent_path(
            concat_node,
            ["Unsqueeze", "Gather", "Shape"],
            [1, 0, 0],
            output_name_to_node,
        )
        if path1 is None:
            return
        (unsqueeze_1, gather_1, shape_1) = path1

        shape = []
        # Gather index for dim 0 must be 0, and for dim 1 must be 1, i.e. the
        # first two output dims copy the first two input dims -> encode as 0.
        gather_value = self.model.get_constant_value(gather_0.input[1])
        if gather_value == 0:
            shape.append(0)

        gather_value = self.model.get_constant_value(gather_1.input[1])
        if gather_value == 1:
            shape.append(0)

        # Both checks above must have appended; otherwise the pattern differs.
        if len(shape) != 2:
            return

        path2 = []
        path3 = []
        # Track every Shape node used, to verify later that they all read the
        # same tensor that is being reshaped.
        shape_nodes = [shape_0, shape_1]
        if len(concat_node.input) == 3 and self.model.get_constant_value(concat_node.input[2]) is None:
            # 3-input Concat with a dynamic last dim: expect a product of two
            # gathered input dims (Mul of Gather/Gather) -> encode as -1.
            path2 = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Mul", "Gather", "Shape"],
                [2, 0, 0, 0],
                output_name_to_node,
            )
            if path2 is None:
                path2 = self.model.match_parent_path(
                    concat_node,
                    ["Unsqueeze", "Mul", "Squeeze", "Slice", "Shape"],
                    [2, 0, 0, 0, 0],
                    output_name_to_node,
                )  # GPT2 exported by PyTorch 1.4 with opset_version=11
                if path2 is None:
                    return

            # Second operand of the Mul must come from the same kind of path.
            path3 = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Mul", "Gather", "Shape"],
                [2, 0, 1, 0],
                output_name_to_node,
            )
            if path3 is None:
                path3 = self.model.match_parent_path(
                    concat_node,
                    ["Unsqueeze", "Mul", "Squeeze", "Slice", "Shape"],
                    [2, 0, 1, 0, 0],
                    output_name_to_node,
                )  # GPT2 exported by PyTorch 1.4 with opset_version=11
                if path3 is None:
                    return

            shape_nodes.extend([path2[-1], path3[-1]])
            shape.append(-1)
        elif len(concat_node.input) > 2:
            # Constant third dim: copy its value(s) straight into the shape.
            concat_value = self.model.get_constant_value(concat_node.input[2])
            if concat_value is None:
                return
            if isinstance(concat_value, np.ndarray):
                shape.extend(concat_value.tolist())
            else:
                shape.append(concat_value)

        if len(concat_node.input) == 4 and self.model.get_constant_value(concat_node.input[3]) is None:
            # Only one inferred (-1) dimension is allowed in a Reshape shape.
            if -1 in shape:
                return

            # 4-input Concat with a dynamic last dim: expect an input dim
            # divided by a constant (Div of Gather) -> encode as -1.
            path2 = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Div", "Gather", "Shape"],
                [3, 0, 0, 0],
                output_name_to_node,
            )
            if path2 is None:
                path2 = self.model.match_parent_path(
                    concat_node,
                    ["Unsqueeze", "Div", "Squeeze", "Slice", "Shape"],
                    [3, 0, 0, 0, 0],
                    output_name_to_node,
                )  # GPT2 exported by PyTorch 1.4 with opset_version=11
                if path2 is None:
                    return
            shape_nodes.extend([path2[-1]])
            shape.append(-1)
        elif len(concat_node.input) > 3:
            # Constant fourth dim: copy its value(s) straight into the shape.
            concat_value = self.model.get_constant_value(concat_node.input[3])
            if concat_value is None:
                return

            if isinstance(concat_value, np.ndarray):
                shape.extend(concat_value.tolist())
            else:
                shape.append(concat_value)

        # All matched Shape nodes must read the tensor being reshaped; the
        # 0 ("copy dim") encoding is only valid relative to that tensor.
        root_input = reshape_node.input[0]
        same_shape_input = True
        for shape_node in shape_nodes:
            if shape_node.input[0] != root_input:
                same_shape_input = False

        if not same_shape_input:
            return

        self.replace_reshape_node(shape, reshape_node, concat_node)

        # TODO(tlwu): Subgraph blocks pruning un-used nodes. Add code to remove un-used nodes safely.
        self.prune_graph = True