onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
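
Since the wheel bundles `DirectML.dll` and `onnxruntime.dll` under `onnxruntime/capi/`, a quick smoke test of the installed package is to create a session on the DirectML execution provider (a minimal sketch; `model.onnx` is a placeholder path):

```python
import onnxruntime as ort

# "DmlExecutionProvider" is the execution provider registered by onnxruntime-directml
session = ort.InferenceSession("model.onnx", providers=["DmlExecutionProvider"])
print(session.get_providers())
```

The diff below covers entry 288, onnxruntime/transformers/onnx_model_phi.py.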
onnxruntime/transformers/onnx_model_phi.py
@@ -0,0 +1,930 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+ from typing import List, Optional
+
+ import numpy as np
+ from dynamo_onnx_helper import DynamoOnnxHelper
+ from fusion_base import Fusion
+ from fusion_options import AttentionOpType, FusionOptions
+ from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
+ from fusion_utils import NumpyHelper
+ from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class ProcessGemmWFunc:
+     def __call__(self, x):
+         return np.transpose(x, (1, 0))
+
+
+ class ProcessMatMulQFunc:
+     def __call__(self, x):
+         return np.transpose(np.split(x, 3, 0)[0], (1, 0))
+
+
+ class ProcessMatMulKFunc:
+     def __call__(self, x):
+         return np.transpose(np.split(x, 3, 0)[1], (1, 0))
+
+
+ class ProcessMatMulVFunc:
+     def __call__(self, x):
+         return np.transpose(np.split(x, 3, 0)[2], (1, 0))
+
+
+ class ProcessBiasQFunc:
+     def __call__(self, x):
+         x = np.split(x, 3, -1)[0]
+         return x
+
+
+ class ProcessBiasKFunc:
+     def __call__(self, x):
+         x = np.split(x, 3, -1)[1]
+         return x
+
+
+ class ProcessBiasVFunc:
+     def __call__(self, x):
+         x = np.split(x, 3, -1)[2]
+         return x
+
+
+ class ProcessRotCacheFunc:
+     def __call__(self, x):
+         # half rotary embedding
+         assert len(x.shape) == 2
+         if x.shape[1] == 32:
+             return x[:, 0:16]
+         return x
+
+
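
The `Process*` functors above capture the file's packed-QKV conventions: projection weights arrive as `[3 * hidden, hidden]` and are sliced along axis 0 then transposed from `[out, in]` to `[in, out]` so they can feed a `MatMul`, while packed biases are `[3 * hidden]` and sliced along the last axis. A minimal numpy sketch of the same slicing, with illustrative shapes:

```python
import numpy as np

hidden = 4
packed_w = np.arange(3 * hidden * hidden, dtype=np.float32).reshape(3 * hidden, hidden)
packed_b = np.arange(3 * hidden, dtype=np.float32)

# ProcessMatMulQFunc: Q third of the packed weight, transposed [out, in] -> [in, out]
q_w = np.transpose(np.split(packed_w, 3, 0)[0], (1, 0))
assert q_w.shape == (hidden, hidden)

# ProcessBiasKFunc: K third of the packed bias
k_b = np.split(packed_b, 3, -1)[1]
assert k_b.shape == (hidden,)
```
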
+ # TODO: move to a separate file
+ class Fission(Fusion):
+     def __init__(
+         self,
+         model: OnnxModel,
+         nodes_to_find: List[str],
+     ):
+         super().__init__(model, "DONOTUSE", nodes_to_find)
+
+     def set_attention_op_type(self, attn_op_type: AttentionOpType):
+         self.attn_op_type = attn_op_type
+
+     def get_uname(self, layer_id, name):
+         return name + "_" + str(layer_id)
+
+     def get_edge_by_name(self, edges, name):
+         for edge in edges:
+             if edge == name or edge.endswith(name) or edge.startswith(name):
+                 return edge
+         raise ValueError(f"Edge {name} not found")
+
+     def get_input_by_name(self, node, name):
+         return self.get_edge_by_name(node.input, name)
+
+     def get_output_by_name(self, node, name):
+         return self.get_edge_by_name(node.output, name)
+
+     def process_initializer(self, initializer_name, functor, custom_name=None):
+         i = self.model.get_initializer(initializer_name)
+         i_np_array = NumpyHelper.to_array(i)
+         processed_i_np_array = functor(i_np_array)
+         new_tensor = helper.make_tensor(
+             initializer_name + "_processed" if custom_name is None else custom_name,
+             data_type=TensorProto.FLOAT,
+             dims=processed_i_np_array.shape,
+             vals=processed_i_np_array.flatten().tobytes(),
+             raw=True,
+         )
+         self.model.add_initializer(new_tensor, self.this_graph_name)
+         return new_tensor.name
+
+     def add_fp32_value_info(self, name):
+         new_value_info = self.model.graph().value_info.add()
+         new_value_info.name = name
+         new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
+
+     def add_int64_value_info(self, name):
+         new_value_info = self.model.graph().value_info.add()
+         new_value_info.name = name
+         new_value_info.type.tensor_type.elem_type = TensorProto.INT64
+
+     def replace_fp32_value_info(self, name, shape):
+         for value_info in self.model.graph().value_info:
+             if value_info.name == name:
+                 self.model.graph().value_info.remove(value_info)
+                 break
+         new_value_info = helper.make_tensor_value_info(
+             name,
+             elem_type=TensorProto.FLOAT,
+             shape=shape,
+         )
+         self.model.graph().value_info.extend([new_value_info])
+
+     def set_unique_name_and_add_nodes(
+         self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str]
+     ):
+         for new_node in subgraph_nodes:
+             for i, name in enumerate(new_node.input):
+                 if name == "":
+                     continue
+                 elif name not in layer_known_edges_names:
+                     new_node.input[i] = self.get_uname(layer_id, name)
+                     self.add_fp32_value_info(new_node.input[i])
+             for i, name in enumerate(new_node.output):
+                 if name == "":
+                     continue
+                 elif name not in layer_known_edges_names:
+                     new_node.output[i] = self.get_uname(layer_id, name)
+                     self.add_fp32_value_info(new_node.output[i])
+             new_node.name = self.get_uname(layer_id, new_node.name)
+             self.nodes_to_add.append(new_node)
+             self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+     def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
+         assert len(inputs) == 3
+         assert len(outputs) == 1
+         node = helper.make_node(
+             "LayerNormalization",
+             inputs=inputs,
+             outputs=outputs,
+             name=prefix + "_LayerNormalization",
+             epsilon=9.999999747378752e-06,
+         )
+         return [node]
+
+     def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
+         assert len(inputs) == 3
+         assert len(outputs) == 1
+         matmul = helper.make_node(
+             "MatMul",
+             inputs=[inputs[0], inputs[1]],
+             outputs=[prefix + "matmul_out"],
+             name=prefix + "MatMul",
+         )
+         add = helper.make_node(
+             "Add",
+             inputs=[prefix + "matmul_out", inputs[2]],
+             outputs=outputs,
+             name=prefix + "Bias",
+         )
+         return [matmul, add]
+
+     def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32):
+         assert len(inputs) == 4
+         assert len(outputs) == 1
+         node = helper.make_node(
+             "RotaryEmbedding",
+             inputs=inputs,
+             outputs=outputs,
+             name=prefix + "RotaryEmbedding",
+             domain="com.microsoft",
+             rotary_embedding_dim=rot_dim,
+             num_heads=num_heads,
+         )
+         return [node]
+
+     def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""):
+         assert len(inputs) == 1
+         assert len(outputs) == 1
+         node = helper.make_node(
+             "FastGelu",
+             inputs=inputs,
+             outputs=outputs,
+             name=prefix + "FastGelu",
+             domain="com.microsoft",
+         )
+         return [node]
+
+     def add(self, inputs: List[str], outputs: List[str], prefix: str = ""):
+         assert len(inputs) == 2
+         assert len(outputs) == 1
+         node = helper.make_node(
+             "Add",
+             inputs=inputs,
+             outputs=outputs,
+             name=prefix + "Add",
+         )
+         return [node]
+
+     def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
+         assert len(inputs) == 8
+         assert len(outputs) == 3
+         node = helper.make_node(
+             "MultiHeadAttention",
+             inputs=inputs,
+             outputs=outputs,
+             name=prefix + "MultiHeadAttention",
+             domain="com.microsoft",
+             num_heads=num_heads,
+             unidirectional=1,
+         )
+         return [node]
+
+     def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
+         assert len(inputs) == 7
+         assert len(outputs) == 3
+         node = helper.make_node(
+             "GroupQueryAttention",
+             inputs=inputs,
+             outputs=outputs,
+             name=prefix + "GroupQueryAttention",
+             domain="com.microsoft",
+             num_heads=num_heads,
+             kv_num_heads=num_heads,
+         )
+         return [node]
+
+     def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
+         assert len(inputs) == 5
+         assert len(outputs) == 2
+         node = helper.make_node(
+             "Attention",
+             inputs=inputs,
+             outputs=outputs,
+             name=prefix + "Attention",
+             domain="com.microsoft",
+             num_heads=num_heads,
+             unidirectional=1,
+             do_rotary=1,
+             rotary_embedding_dim=32,
+         )
+         return [node]
+
+     def paged_attn(
+         self,
+         inputs: List[str],
+         outputs: List[str],
+         prefix: str = "",
+         num_heads=32,
+         head_size=80,
+         scale=0.11180339753627777,
+     ):
+         assert len(inputs) == 6
+         assert len(outputs) == 1
+         node = helper.make_node(
+             "PagedAttention",
+             inputs=inputs,
+             outputs=outputs,
+             name=prefix + "PagedAttention",
+             domain="vllm.ort.ext",
+             num_heads=num_heads,
+             num_kv_heads=num_heads,
+             head_size=head_size,
+             scale=scale,
+         )
+         return [node]
+
+
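
One detail worth calling out in `paged_attn`: the default `scale=0.11180339753627777` is simply `1 / sqrt(head_size)` for the default `head_size=80`, stored at float32 precision. A quick check:

```python
import math

import numpy as np

# float32 rounding of 1/sqrt(80) reproduces the literal used above
assert float(np.float32(1.0 / math.sqrt(80))) == 0.11180339753627777
```
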
+ class Phi2PreProcessor(DynamoOnnxHelper):
+     def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
+         super().__init__(model)
+         self.num_hidden_layers = 32
+         self.num_attention_heads = num_heads
+         self.hidden_size = hidden_size
+
+         self.func_name = "modeling_phi_PhiModel_model_1"
+
+     def get_phi2_edge_dict(self) -> dict:
+         edge_dict = {}
+         edge_dict["lm_head_1"] = "logits"
+         edge_dict["l_input_ids_"] = "input_ids"
+         edge_dict["key_states"] = "past_key_0"
+         edge_dict["value_states"] = "past_value_0"
+         for i in range(1, self.num_hidden_layers, 1):
+             edge_dict[f"key_states_{i}"] = f"past_key_{i}"
+             edge_dict[f"value_states_{i}"] = f"past_value_{i}"
+             edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}"
+             edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"
+
+         outputs = [o.name for o in self.model.graph.output]
+         if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs:
+             edge_dict["model_layers_0_1_1"] = "present_key_0"
+             edge_dict["model_layers_0_1_2"] = "present_value_0"
+         else:
+             assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs
+             edge_dict["model_layers_0_1"] = "present_key_0"
+             edge_dict["model_layers_0_1_1"] = "present_value_0"
+         return edge_dict
+
+     def simplify_phi2_op_type(self):
+         phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers"
+         for node in self.model.graph.node:
+             index = node.op_type.find(phi2_transformer_layer_name)
+             if index != -1:
+                 node.op_type = node.op_type[index:]
+
+     def process_graph_io(self, attn_op_type: AttentionOpType):
+         self.use_attn = attn_op_type == AttentionOpType.Attention
+         self.use_vllm = attn_op_type == AttentionOpType.PagedAttention
+         graph = self.model.graph
+         new_inputs = []
+         for vi in graph.input:
+             if "input_ids" in vi.name:
+                 vi_iid = helper.make_tensor_value_info(
+                     vi.name,
+                     elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64,
+                     shape=["batch_size", "seq_len"],
+                 )
+                 vi_step = helper.make_tensor_value_info(
+                     "step",
+                     elem_type=TensorProto.INT64,
+                     shape=[1],
+                 )
+                 vi_pid = helper.make_tensor_value_info(
+                     "position_ids",
+                     elem_type=TensorProto.INT64,
+                     shape=["batch_size", "seq_len"],
+                 )
+                 vi_mask = helper.make_tensor_value_info(
+                     "attention_mask",
+                     elem_type=TensorProto.INT32,
+                     shape=["batch_size", "seq_len"],
+                 )
+                 vi_meta = helper.make_tensor_value_info(
+                     "input_metadata",
+                     elem_type=TensorProto.INT64,
+                     shape=[1],
+                 )
+                 if not self.use_vllm:
+                     new_inputs.extend([vi_iid, vi_step, vi_mask])
+                 else:
+                     new_inputs.extend([vi_iid, vi_pid, vi_meta])
+             if self.use_attn:
+                 if "past_key" in vi.name:
+                     vi_cache = helper.make_tensor_value_info(
+                         vi.name.replace("past_key", "past"),
+                         elem_type=vi.type.tensor_type.elem_type,
+                         shape=[
+                             2,
+                             "batch_size",
+                             self.num_attention_heads,
+                             "past_seq_len",
+                             self.hidden_size // self.num_attention_heads,
+                         ],
+                     )
+                     new_inputs.extend([vi_cache])
+             elif self.use_vllm:
+                 if "past_key" in vi.name:
+                     vi_cache = helper.make_tensor_value_info(
+                         vi.name,
+                         elem_type=vi.type.tensor_type.elem_type,
+                         shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"],
+                     )
+                     new_inputs.extend([vi_cache])
+                 if "past_value" in vi.name:
+                     vi_cache = helper.make_tensor_value_info(
+                         vi.name,
+                         elem_type=vi.type.tensor_type.elem_type,
+                         shape=[
+                             "num_blocks",
+                             "num_heads",
+                             "head_size",
+                             "block_size",
+                         ],
+                     )
+                     new_inputs.extend([vi_cache])
+             else:
+                 if "past_key" in vi.name or "past_value" in vi.name:
+                     vi_cache = helper.make_tensor_value_info(
+                         vi.name,
+                         elem_type=vi.type.tensor_type.elem_type,
+                         shape=[
+                             "batch_size",
+                             self.num_attention_heads,
+                             "past_seq_len",
+                             self.hidden_size // self.num_attention_heads,
+                         ],
+                     )
+                     new_inputs.extend([vi_cache])
+
+         graph.ClearField("input")
+         graph.input.extend(new_inputs)
+
+         new_outputs = []
+         for i, vi in enumerate(graph.output):
+             if i == 0:
+                 new_outputs.extend([vi])
+             else:
+                 if self.use_attn:
+                     if "present_key" in vi.name:
+                         vi_cache = helper.make_tensor_value_info(
+                             vi.name.replace("present_key", "present"),
+                             elem_type=vi.type.tensor_type.elem_type,
+                             shape=[
+                                 2,
+                                 "batch_size",
+                                 self.num_attention_heads,
+                                 "total_seq_len",
+                                 self.hidden_size // self.num_attention_heads,
+                             ],
+                         )
+                         new_outputs.extend([vi_cache])
+                 elif self.use_vllm:
+                     pass
+                 else:
+                     vi_cache = helper.make_tensor_value_info(
+                         vi.name,
+                         elem_type=vi.type.tensor_type.elem_type,
+                         shape=[
+                             "batch_size",
+                             self.num_attention_heads,
+                             "total_seq_len",
+                             self.hidden_size // self.num_attention_heads,
+                         ],
+                     )
+                     new_outputs.extend([vi_cache])
+
+         graph.ClearField("output")
+         graph.output.extend(new_outputs)
+
+     def preprocess_onnx(self, attn_op_type: AttentionOpType):
+         function_name = None
+         for func in self.model.functions:
+             if func.name.endswith(self.func_name):
+                 function_name = func.name
+                 break
+         assert function_name is not None
+         self.unroll_function(function_name)
+         self.update_edges(self.get_phi2_edge_dict())
+         self.simplify_phi2_op_type()
+         self.remove_dropout_layer()
+         if attn_op_type == AttentionOpType.PagedAttention:
+             self.remove_lm_head_layer()
+         self.process_graph_io(attn_op_type)
+
+
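
`get_phi2_edge_dict` renames the Dynamo-exported edges to the canonical `input_ids`/`logits`/`past_*`/`present_*` names that the fission passes and the rebuilt graph inputs/outputs rely on. An illustrative subset of the mapping it builds (the layer-0 `present_*` keys depend on which output names the export produced):

```python
# Hypothetical excerpt of the dict returned by get_phi2_edge_dict for layers 0-1
{
    "lm_head_1": "logits",
    "l_input_ids_": "input_ids",
    "key_states": "past_key_0",
    "value_states": "past_value_0",
    "key_states_1": "past_key_1",
    "value_states_1": "past_value_1",
    "model_layers_1_1": "present_key_1",
    "model_layers_1_1_1": "present_value_1",
}
```
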
+ class FissionTransformerEmbeddingPhi(Fission):
+     def __init__(
+         self,
+         model: OnnxModel,
+     ):
+         super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"])
+
+     def fuse(self, node, input_name_to_nodes, output_name_to_node):
+         logger.info("Optimizing %s...", node.name)
+
+         assert len(node.input) == 2
+         assert len(node.output) == 1
+
+         input = node.input[0]
+         output = node.output[0]
+
+         embedding = self.get_input_by_name(node, "embed_tokens.weight")
+
+         layer_known_edges_names = [input, output, embedding]
+
+         subgraph_nodes = [
+             helper.make_node(
+                 "Gather",
+                 inputs=[embedding, input],
+                 outputs=[output],
+                 name="Embedding_Gather",
+             ),
+         ]
+
+         self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names)
+         self.nodes_to_remove.append(node)
+         self.prune_graph = True
+
+
+ class FissionTransformerLayerNormPhi(Fission):
+     def __init__(
+         self,
+         model: OnnxModel,
+     ):
+         super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"])
+
+     def fuse(self, node, input_name_to_nodes, output_name_to_node):
+         logger.info("Optimizing %s...", node.name)
+
+         assert len(node.input) == 3
+         assert len(node.output) == 1
+
+         input = node.input[0]
+         output = node.output[0]
+
+         ln_weight = self.get_input_by_name(node, "final_layernorm.weight")
+         ln_bias = self.get_input_by_name(node, "final_layernorm.bias")
+
+         layer_known_edges_names = [input, output, ln_weight, ln_bias]
+
+         subgraph_nodes = []
+         subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final"))
+
+         self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
+
+         self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
+         self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"])
+
+         self.nodes_to_remove.append(node)
+         self.prune_graph = True
+
+
+ class FissionTransformerCausalLMHeadPhi(Fission):
+     def __init__(
+         self,
+         model: OnnxModel,
+     ):
+         super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"])
+
+     def fuse(self, node, input_name_to_nodes, output_name_to_node):
+         logger.info("Optimizing %s...", node.name)
+
+         assert len(node.input) == 5
+         assert len(node.output) == 1
+
+         input = node.input[2]
+         output = node.output[0]
+
+         fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
+         fc_bias = self.get_input_by_name(node, "lm_head.bias")
+
+         layer_known_edges_names = [input, output, fc_weight, fc_bias]
+
+         subgraph_nodes = []
+         subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_"))
+
+         self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
+
+         self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
+         self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200])
+
+         self.nodes_to_remove.append(node)
+         self.prune_graph = True
+
+
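
The hard-coded `51200` above is phi-2's vocabulary size, so the LM head's output `value_info` becomes `["batch_size", "seq_len", 51200]`. Numerically, the `gemm` helper it uses is just `logits = hidden @ W + b` split into `MatMul` + `Add`; a sketch with phi-2's sizes (illustrative, not part of the package):

```python
import numpy as np

batch, seq, hidden, vocab = 1, 8, 2560, 51200  # phi-2 configuration
x = np.zeros((batch, seq, hidden), dtype=np.float32)
w = np.zeros((hidden, vocab), dtype=np.float32)  # lm_head.weight after the ProcessGemmWFunc transpose
b = np.zeros((vocab,), dtype=np.float32)

logits = x @ w + b  # MatMul ("LMHead_MatMul") followed by Add ("LMHead_Bias")
assert logits.shape == (batch, seq, vocab)
```
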
+ class FissionTransformerBlockPhi(Fission):
+     def __init__(
+         self,
+         model: OnnxModel,
+         num_heads: int,
+     ):
+         self.num_heads = num_heads
+         max_num_layers = 32
+         self.func_to_layer_id = {}
+         nodes_to_find = []
+         for layer in range(max_num_layers):
+             func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1"
+             nodes_to_find.append(func_name)
+             self.func_to_layer_id[func_name] = layer
+
+         super().__init__(model, nodes_to_find)
+
+     def get_layer_id(self, node):
+         return self.func_to_layer_id[node.op_type]
+
+     def get_gqa_aux_nodes(self):
+         gqa_aux_nodes = [
+             helper.make_node(
+                 "Cast",
+                 inputs=["attention_mask"],
+                 outputs=["mask_int64"],
+                 name="Cast_gqa_aux_0",
+                 to=TensorProto.INT64,
+             ),
+             helper.make_node(
+                 "ReduceSum",
+                 inputs=["mask_int64", "one"],
+                 outputs=["mask_row_sums"],
+                 name="ReduceSum_gqa_aux",
+             ),
+             helper.make_node(
+                 "Sub",
+                 inputs=["mask_row_sums", "one"],
+                 outputs=["seqlens_k_int64"],
+                 name="Sub_gqa_aux",
+             ),
+             helper.make_node(
+                 "Cast",
+                 inputs=["seqlens_k_int64"],
+                 outputs=["seqlens_k"],
+                 name="Cast_gqa_aux_1",
+                 to=TensorProto.INT32,
+             ),
+             helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"),
+             helper.make_node(
+                 "Gather",
+                 inputs=["mask_shape", "one"],
+                 outputs=["total_seq_len_int64"],
+                 name="Gather_gqa_aux_0",
+                 axis=0,
+             ),
+             helper.make_node(
+                 "Cast",
+                 inputs=["total_seq_len_int64"],
+                 outputs=["total_sequence_length"],
+                 name="Cast_gqa_aux_2",
+                 to=TensorProto.INT32,
+             ),
+         ]
+         return gqa_aux_nodes
+
+     def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name):
+         q_weight = self.model.get_initializer(q_w)
+         k_weight = self.model.get_initializer(k_w)
+         v_weight = self.model.get_initializer(v_w)
+         qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0))
+         kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0))
+         vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0))
+         qkv_weight = np.stack((qw, kw, vw), axis=1)
+
+         q_bias = self.model.get_initializer(q_b)
+         k_bias = self.model.get_initializer(k_b)
+         v_bias = self.model.get_initializer(v_b)
+         qb = NumpyHelper.to_array(q_bias)
+         kb = NumpyHelper.to_array(k_bias)
+         vb = NumpyHelper.to_array(v_bias)
+         qkv_bias = np.stack((qb, kb, vb), axis=0)
+
+         hidden_size = qkv_weight.shape[0]
+
+         weight = helper.make_tensor(
+             weight_name,
+             data_type=TensorProto.FLOAT,
+             dims=[hidden_size, hidden_size * 3],
+             vals=qkv_weight.flatten().tobytes(),
+             raw=True,
+         )
+         self.model.add_initializer(weight, self.this_graph_name)
+
+         bias = helper.make_tensor(
+             bias_name,
+             data_type=TensorProto.FLOAT,
+             dims=[hidden_size * 3],
+             vals=qkv_bias.flatten().tobytes(),
+             raw=True,
+         )
+         self.model.add_initializer(bias, self.this_graph_name)
+
+         self.add_fp32_value_info(weight.name)
+         self.add_fp32_value_info(bias.name)
+
+         return weight_name, bias_name
+
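
`pack_qkv_gemm` stacks the transposed projections on axis 1 and flattens to a `[hidden, 3 * hidden]` initializer, which is the same layout as concatenating `[Wq | Wk | Wv]` column-wise. A quick numpy check with illustrative sizes:

```python
import numpy as np

hidden = 4
rng = np.random.default_rng(0)
qw, kw, vw = (rng.standard_normal((hidden, hidden), dtype=np.float32) for _ in range(3))

packed = np.stack((qw, kw, vw), axis=1).reshape(hidden, 3 * hidden)
assert np.array_equal(packed, np.concatenate((qw, kw, vw), axis=1))
```
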
+     def fuse(
+         self,
+         node,
+         input_name_to_nodes,
+         output_name_to_node,
+     ):
+         logger.info("Optimizing %s...", node.name)
+
+         logger.info(f"AttentionOpType: {self.attn_op_type}")
+
+         layer_id = self.get_layer_id(node)
+
+         i_hidden_states = node.input[0]
+         i_key_cache = self.get_input_by_name(node, "past_key")
+         i_value_cache = self.get_input_by_name(node, "past_value")
+
+         o_hidden_states = node.output[-1]
+         o_key_cache = self.get_output_by_name(node, "present_key")
+         o_value_cache = self.get_output_by_name(node, "present_value")
+
+         ln_weight = self.get_input_by_name(node, "input_layernorm.weight")
+         ln_bias = self.get_input_by_name(node, "input_layernorm.bias")
+
+         attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = (
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+         )
+         attn_qkv_weight, attn_qkv_bias = None, None
+         cos_cache, sin_cache = None, None
+
+         if self.attn_op_type != AttentionOpType.Attention:
+             attn_q_weight = self.process_initializer(
+                 self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
+             )
+             attn_k_weight = self.process_initializer(
+                 self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
+             )
+             attn_v_weight = self.process_initializer(
+                 self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
+             )
+             attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias")
+             attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias")
+             attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias")
+
+             cos_cache = self.process_initializer(
+                 self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
+             )
+             sin_cache = self.process_initializer(
+                 self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
+             )
+         else:
+             attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
+                 self.get_input_by_name(node, "self_attn.q_proj.weight"),
+                 self.get_input_by_name(node, "self_attn.k_proj.weight"),
+                 self.get_input_by_name(node, "self_attn.v_proj.weight"),
+                 self.get_input_by_name(node, "self_attn.q_proj.bias"),
+                 self.get_input_by_name(node, "self_attn.k_proj.bias"),
+                 self.get_input_by_name(node, "self_attn.v_proj.bias"),
+                 self.get_uname(layer_id, "attn_qkv_weight"),
+                 self.get_uname(layer_id, "attn_qkv_bias"),
+             )
+
+         attn_out_weight = self.process_initializer(
+             self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
+         )
+         attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias")
+
+         mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
+         mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
+         mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias")
+         mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias")
+
+         layer_known_edges_names = []
+         layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache])
+         layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache])
+         layer_known_edges_names.extend([ln_weight, ln_bias])
+         if self.attn_op_type != AttentionOpType.Attention:
+             layer_known_edges_names.extend(
+                 [
+                     attn_q_weight,
+                     attn_q_bias,
+                     attn_k_weight,
+                     attn_k_bias,
+                     attn_v_weight,
+                     attn_v_bias,
+                     cos_cache,
+                     sin_cache,
+                 ]
+             )
+         else:
+             layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias])
+         layer_known_edges_names.extend(
+             [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias]
+         )
+         layer_known_edges_names.extend(
+             ["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"]
+         )
+
+         subgraph_nodes = []
+         subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
+         subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_"))
+         subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_"))
+         subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"]))
+         subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_"))
+         subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1"))
+         subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2"))
+         if self.attn_op_type != AttentionOpType.Attention:
+             subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
+             subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
+             subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
+             # vllm engine requires full position ids as the input
+             pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step"
+             subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_"))
+             subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))
+             if self.attn_op_type == AttentionOpType.MultiHeadAttention:
+                 subgraph_nodes.extend(
+                     self.mha(
+                         ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache],
+                         ["attn_out", o_key_cache, o_value_cache],
+                     )
+                 )
+             elif self.attn_op_type == AttentionOpType.GroupQueryAttention:
+                 subgraph_nodes.extend(
+                     self.gqa(
+                         [
+                             "query_rot",
+                             "key_rot",
+                             "value",
+                             i_key_cache,
+                             i_value_cache,
+                             "seqlens_k",
+                             "total_sequence_length",
+                         ],
+                         ["attn_out", o_key_cache, o_value_cache],
+                     )
+                 )
+                 if layer_id == 0:
+                     gqa_aux_nodes = self.get_gqa_aux_nodes()
+                     for new_node in gqa_aux_nodes:
+                         self.nodes_to_add.append(new_node)
+                         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+                     self.model.add_initializer(
+                         numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
+                     )
+             elif self.attn_op_type == AttentionOpType.PagedAttention:
+                 subgraph_nodes.extend(
+                     self.paged_attn(
+                         ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"],
+                         ["attn_out"],
+                     )
+                 )
+         else:
+             past_name = f"past_{layer_id}"
+             present_name = f"present_{layer_id}"
+             layer_known_edges_names.extend([past_name, present_name])
+             subgraph_nodes.extend(
+                 self.attention(
+                     ["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name]
+                 )
+             )
+
+         self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names)
+
+         self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"])
+         self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"])
+
+         self.nodes_to_remove.append(node)
+         self.prune_graph = True
+
+
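
The subgraph that `fuse` rebuilds reflects phi-2's parallel block design: attention and the MLP both read the same `ln_out`, and `Residual_1`/`Residual_2` fold both paths back into the residual stream. A sketch of the dataflow using the edge names above (the callables are hypothetical stand-ins for the emitted ONNX nodes):

```python
def phi2_block(hidden_states, layer_norm, attn, out_proj, fc1, fast_gelu, fc2):
    """Dataflow of the reconstructed decoder layer; names mirror the edges above."""
    ln_out = layer_norm(hidden_states)        # LayerNormalization -> "ln_out"
    attn_add_out = out_proj(attn(ln_out))     # attention -> "attn_out" -> OutProj_ gemm
    fc2_out = fc2(fast_gelu(fc1(ln_out)))     # FC1_ -> FastGelu -> FC2_
    residual_1_out = attn_add_out + fc2_out   # Residual_1 Add
    return hidden_states + residual_1_out     # Residual_2 Add -> layer output
```
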
+ class PhiOnnxModel(OnnxModel):
+     def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
+         super().__init__(model)
+         self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size)
+         self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads)
+         self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self)
+         self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self)
+         self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self)
+
+     def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
+         assert options is not None
+         attn_op_type = options.attention_op_type
+
+         self.fission_transformer_block.set_attention_op_type(attn_op_type)
+
+         self.phi2_preprocessor.preprocess_onnx(attn_op_type)
+
+         self.fission_transformer_block.apply()
+         self.fission_transformer_layernorm.apply()
+         self.fission_causal_lm_head.apply()
+         self.fission_transformer_embedding.apply()
+
+         super().prune_graph()
+
+         # SLN ctor is placed here intentionally to delay the symbolic shape inference
+         self.fuse_sln = FusionSkipLayerNormalization(self)
+         self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self)
+         self.fuse_sln.apply()
+         self.fuse_bias_sln.apply()
+
+     def get_fused_operator_statistics(self):
+         """
+         Returns node count of fused operators.
+         """
+         op_count = {}
+         ops = [
+             "Attention",
+             "MultiHeadAttention",
+             "GroupQueryAttention",
+             "PagedAttention",
+             "Gelu",
+             "BiasGelu",
+             "FastGelu",
+             "LayerNormalization",
+             "SkipLayerNormalization",
+         ]
+         for op in ops:
+             nodes = self.get_nodes_by_op_type(op)
+             op_count[op] = len(nodes)
+
+         logger.info(f"Optimized operators: {op_count}")
+         return op_count
+
+     def is_fully_optimized(self, fused_op_count=None):
+         """
+         Returns True when the model is fully optimized.
+         """
+         if fused_op_count is None:
+             fused_op_count = self.get_fused_operator_statistics()
+
+         def op_count(op_name: str):
+             return fused_op_count.get(op_name) or 0
+
+         attention = (
+             op_count("Attention")
+             + op_count("MultiHeadAttention")
+             + op_count("GroupQueryAttention")
+             + op_count("PagedAttention")
+         )
+         gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
+         layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
+
+         is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention)
+
+         if layer_norm == 0:
+             logger.debug("Layer Normalization not fused")
+
+         if gelu == 0:
+             logger.debug("Gelu (or FastGelu) not fused")
+
+         if attention == 0:
+             logger.warning("Attention (or MultiHeadAttention) not fused")
+
+         return is_perfect
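
For orientation, a hedged sketch of driving `PhiOnnxModel` directly (the shipped tooling normally reaches it through the Phi-2 conversion scripts under `onnxruntime/transformers/models/phi2/`; `phi2.onnx` is a placeholder path, the sizes are phi-2's 32 heads and 2560 hidden units, and the attribute assignment mirrors the `options.attention_op_type` read in `optimize()` above):

```python
import onnx
from fusion_options import AttentionOpType, FusionOptions  # script-style imports, as in this module
from onnx_model_phi import PhiOnnxModel

model = onnx.load("phi2.onnx")  # placeholder for a Dynamo-exported phi-2 model
phi_model = PhiOnnxModel(model, num_heads=32, hidden_size=2560)

options = FusionOptions("phi")
options.attention_op_type = AttentionOpType.GroupQueryAttention  # the field optimize() reads
phi_model.optimize(options)

print(phi_model.get_fused_operator_statistics())
onnx.save(phi_model.model, "phi2_opt.onnx")
```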