onnxruntime-directml 1.24.1 (onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl)

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/onnx_model_phi.py
@@ -0,0 +1,929 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from logging import getLogger
+
+import numpy as np
+from dynamo_onnx_helper import DynamoOnnxHelper
+from fusion_base import Fusion
+from fusion_options import AttentionOpType, FusionOptions
+from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
+from fusion_utils import NumpyHelper
+from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class ProcessGemmWFunc:
+    def __call__(self, x):
+        return np.transpose(x, (1, 0))
+
+
+class ProcessMatMulQFunc:
+    def __call__(self, x):
+        return np.transpose(np.split(x, 3, 0)[0], (1, 0))
+
+
+class ProcessMatMulKFunc:
+    def __call__(self, x):
+        return np.transpose(np.split(x, 3, 0)[1], (1, 0))
+
+
+class ProcessMatMulVFunc:
+    def __call__(self, x):
+        return np.transpose(np.split(x, 3, 0)[2], (1, 0))
+
+
+class ProcessBiasQFunc:
+    def __call__(self, x):
+        x = np.split(x, 3, -1)[0]
+        return x
+
+
+class ProcessBiasKFunc:
+    def __call__(self, x):
+        x = np.split(x, 3, -1)[1]
+        return x
+
+
+class ProcessBiasVFunc:
+    def __call__(self, x):
+        x = np.split(x, 3, -1)[2]
+        return x
+
+
+class ProcessRotCacheFunc:
+    def __call__(self, x):
+        # half rotary embedding
+        assert len(x.shape) == 2
+        if x.shape[1] == 32:
+            return x[:, 0:16]
+        return x
+
+
+# TODO: move to a separate file
+class Fission(Fusion):
+    def __init__(
+        self,
+        model: OnnxModel,
+        nodes_to_find: list[str],
+    ):
+        super().__init__(model, "DONOTUSE", nodes_to_find)
+
+    def set_attention_op_type(self, attn_op_type: AttentionOpType):
+        self.attn_op_type = attn_op_type
+
+    def get_uname(self, layer_id, name):
+        return name + "_" + str(layer_id)
+
+    def get_edge_by_name(self, edges, name):
+        for edge in edges:
+            if edge == name or edge.endswith(name) or edge.startswith(name):
+                return edge
+        raise ValueError(f"Edge {name} not found")
+
+    def get_input_by_name(self, node, name):
+        return self.get_edge_by_name(node.input, name)
+
+    def get_output_by_name(self, node, name):
+        return self.get_edge_by_name(node.output, name)
+
+    def process_initializer(self, initializer_name, functor, custom_name=None):
+        i = self.model.get_initializer(initializer_name)
+        i_np_array = NumpyHelper.to_array(i)
+        processed_i_np_array = functor(i_np_array)
+        new_tensor = helper.make_tensor(
+            initializer_name + "_processed" if custom_name is None else custom_name,
+            data_type=TensorProto.FLOAT,
+            dims=processed_i_np_array.shape,
+            vals=processed_i_np_array.flatten().tobytes(),
+            raw=True,
+        )
+        self.model.add_initializer(new_tensor, self.this_graph_name)
+        return new_tensor.name
+
+    def add_fp32_value_info(self, name):
+        new_value_info = self.model.graph().value_info.add()
+        new_value_info.name = name
+        new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
+
+    def add_int64_value_info(self, name):
+        new_value_info = self.model.graph().value_info.add()
+        new_value_info.name = name
+        new_value_info.type.tensor_type.elem_type = TensorProto.INT64
+
+    def replace_fp32_value_info(self, name, shape):
+        for value_info in self.model.graph().value_info:
+            if value_info.name == name:
+                self.model.graph().value_info.remove(value_info)
+                break
+        new_value_info = helper.make_tensor_value_info(
+            name,
+            elem_type=TensorProto.FLOAT,
+            shape=shape,
+        )
+        self.model.graph().value_info.extend([new_value_info])
+
+    def set_unique_name_and_add_nodes(
+        self, subgraph_nodes: list[NodeProto], layer_id: int, layer_known_edges_names: list[str]
+    ):
+        for new_node in subgraph_nodes:
+            for i, name in enumerate(new_node.input):
+                if name == "":
+                    continue
+                elif name not in layer_known_edges_names:
+                    new_node.input[i] = self.get_uname(layer_id, name)
+                    self.add_fp32_value_info(new_node.input[i])
+            for i, name in enumerate(new_node.output):
+                if name == "":
+                    continue
+                elif name not in layer_known_edges_names:
+                    new_node.output[i] = self.get_uname(layer_id, name)
+                    self.add_fp32_value_info(new_node.output[i])
+            new_node.name = self.get_uname(layer_id, new_node.name)
+            self.nodes_to_add.append(new_node)
+            self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+    def layernorm(self, inputs: list[str], outputs: list[str], prefix: str = ""):
+        assert len(inputs) == 3
+        assert len(outputs) == 1
+        node = helper.make_node(
+            "LayerNormalization",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "_LayerNormalization",
+            epsilon=9.999999747378752e-06,
+        )
+        return [node]
+
+    def gemm(self, inputs: list[str], outputs: list[str], prefix: str = ""):
+        assert len(inputs) == 3
+        assert len(outputs) == 1
+        matmul = helper.make_node(
+            "MatMul",
+            inputs=[inputs[0], inputs[1]],
+            outputs=[prefix + "matmul_out"],
+            name=prefix + "MatMul",
+        )
+        add = helper.make_node(
+            "Add",
+            inputs=[prefix + "matmul_out", inputs[2]],
+            outputs=outputs,
+            name=prefix + "Bias",
+        )
+        return [matmul, add]
+
+    def rotary(self, inputs: list[str], outputs: list[str], prefix: str = "", rot_dim=32, num_heads=32):
+        assert len(inputs) == 4
+        assert len(outputs) == 1
+        node = helper.make_node(
+            "RotaryEmbedding",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "RotaryEmbedding",
+            domain="com.microsoft",
+            rotary_embedding_dim=rot_dim,
+            num_heads=num_heads,
+        )
+        return [node]
+
+    def fastgelu(self, inputs: list[str], outputs: list[str], prefix: str = ""):
+        assert len(inputs) == 1
+        assert len(outputs) == 1
+        node = helper.make_node(
+            "FastGelu",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "FastGelu",
+            domain="com.microsoft",
+        )
+        return [node]
+
+    def add(self, inputs: list[str], outputs: list[str], prefix: str = ""):
+        assert len(inputs) == 2
+        assert len(outputs) == 1
+        node = helper.make_node(
+            "Add",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "Add",
+        )
+        return [node]
+
+    def mha(self, inputs: list[str], outputs: list[str], prefix: str = "", num_heads=32):
+        assert len(inputs) == 8
+        assert len(outputs) == 3
+        node = helper.make_node(
+            "MultiHeadAttention",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "MultiHeadAttention",
+            domain="com.microsoft",
+            num_heads=num_heads,
+            unidirectional=1,
+        )
+        return [node]
+
+    def gqa(self, inputs: list[str], outputs: list[str], prefix: str = "", num_heads=32):
+        assert len(inputs) == 7
+        assert len(outputs) == 3
+        node = helper.make_node(
+            "GroupQueryAttention",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "GroupQueryAttention",
+            domain="com.microsoft",
+            num_heads=num_heads,
+            kv_num_heads=num_heads,
+        )
+        return [node]
+
+    def attention(self, inputs: list[str], outputs: list[str], prefix: str = "", num_heads=32):
+        assert len(inputs) == 5
+        assert len(outputs) == 2
+        node = helper.make_node(
+            "Attention",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "Attention",
+            domain="com.microsoft",
+            num_heads=num_heads,
+            unidirectional=1,
+            do_rotary=1,
+            rotary_embedding_dim=32,
+        )
+        return [node]
+
+    def paged_attn(
+        self,
+        inputs: list[str],
+        outputs: list[str],
+        prefix: str = "",
+        num_heads=32,
+        head_size=80,
+        scale=0.11180339753627777,
+    ):
+        assert len(inputs) == 6
+        assert len(outputs) == 1
+        node = helper.make_node(
+            "PagedAttention",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "PagedAttention",
+            domain="vllm.ort.ext",
+            num_heads=num_heads,
+            num_kv_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+        )
+        return [node]
+
+
+class Phi2PreProcessor(DynamoOnnxHelper):
+    def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
+        super().__init__(model)
+        self.num_hidden_layers = 32
+        self.num_attention_heads = num_heads
+        self.hidden_size = hidden_size
+
+        self.func_name = "modeling_phi_PhiModel_model_1"
+
+    def get_phi2_edge_dict(self) -> dict:
+        edge_dict = {}
+        edge_dict["lm_head_1"] = "logits"
+        edge_dict["l_input_ids_"] = "input_ids"
+        edge_dict["key_states"] = "past_key_0"
+        edge_dict["value_states"] = "past_value_0"
+        for i in range(1, self.num_hidden_layers, 1):
+            edge_dict[f"key_states_{i}"] = f"past_key_{i}"
+            edge_dict[f"value_states_{i}"] = f"past_value_{i}"
+            edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}"
+            edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"
+
+        outputs = [o.name for o in self.model.graph.output]
+        if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs:
+            edge_dict["model_layers_0_1_1"] = "present_key_0"
+            edge_dict["model_layers_0_1_2"] = "present_value_0"
+        else:
+            assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs
+            edge_dict["model_layers_0_1"] = "present_key_0"
+            edge_dict["model_layers_0_1_1"] = "present_value_0"
+        return edge_dict
+
+    def simplify_phi2_op_type(self):
+        phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers"
+        for node in self.model.graph.node:
+            index = node.op_type.find(phi2_transformer_layer_name)
+            if index != -1:
+                node.op_type = node.op_type[index:]
+
+    def process_graph_io(self, attn_op_type: AttentionOpType):
+        self.use_attn = attn_op_type == AttentionOpType.Attention
+        self.use_vllm = attn_op_type == AttentionOpType.PagedAttention
+        graph = self.model.graph
+        new_inputs = []
+        for vi in graph.input:
+            if "input_ids" in vi.name:
+                vi_iid = helper.make_tensor_value_info(
+                    vi.name,
+                    elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64,
+                    shape=["batch_size", "seq_len"],
+                )
+                vi_step = helper.make_tensor_value_info(
+                    "step",
+                    elem_type=TensorProto.INT64,
+                    shape=[1],
+                )
+                vi_pid = helper.make_tensor_value_info(
+                    "position_ids",
+                    elem_type=TensorProto.INT64,
+                    shape=["batch_size", "seq_len"],
+                )
+                vi_mask = helper.make_tensor_value_info(
+                    "attention_mask",
+                    elem_type=TensorProto.INT32,
+                    shape=["batch_size", "seq_len"],
+                )
+                vi_meta = helper.make_tensor_value_info(
+                    "input_metadata",
+                    elem_type=TensorProto.INT64,
+                    shape=[1],
+                )
+                (
+                    new_inputs.extend([vi_iid, vi_step, vi_mask])
+                    if not self.use_vllm
+                    else new_inputs.extend([vi_iid, vi_pid, vi_meta])
+                )
+            if self.use_attn:
+                if "past_key" in vi.name:
+                    vi_cache = helper.make_tensor_value_info(
+                        vi.name.replace("past_key", "past"),
+                        elem_type=vi.type.tensor_type.elem_type,
+                        shape=[
+                            2,
+                            "batch_size",
+                            self.num_attention_heads,
+                            "past_seq_len",
+                            self.hidden_size // self.num_attention_heads,
+                        ],
+                    )
+                    new_inputs.extend([vi_cache])
+            elif self.use_vllm:
+                if "past_key" in vi.name:
+                    vi_cache = helper.make_tensor_value_info(
+                        vi.name,
+                        elem_type=vi.type.tensor_type.elem_type,
+                        shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"],
+                    )
+                    new_inputs.extend([vi_cache])
+                if "past_value" in vi.name:
+                    vi_cache = helper.make_tensor_value_info(
+                        vi.name,
+                        elem_type=vi.type.tensor_type.elem_type,
+                        shape=[
+                            "num_blocks",
+                            "num_heads",
+                            "head_size",
+                            "block_size",
+                        ],
+                    )
+                    new_inputs.extend([vi_cache])
+            else:
+                if "past_key" in vi.name or "past_value" in vi.name:
+                    vi_cache = helper.make_tensor_value_info(
+                        vi.name,
+                        elem_type=vi.type.tensor_type.elem_type,
+                        shape=[
+                            "batch_size",
+                            self.num_attention_heads,
+                            "past_seq_len",
+                            self.hidden_size // self.num_attention_heads,
+                        ],
+                    )
+                    new_inputs.extend([vi_cache])
+
+        graph.ClearField("input")
+        graph.input.extend(new_inputs)
+
+        new_outputs = []
+        for i, vi in enumerate(graph.output):
+            if i == 0:
+                new_outputs.extend([vi])
+            else:
+                if self.use_attn:
+                    if "present_key" in vi.name:
+                        vi_cache = helper.make_tensor_value_info(
+                            vi.name.replace("present_key", "present"),
+                            elem_type=vi.type.tensor_type.elem_type,
+                            shape=[
+                                2,
+                                "batch_size",
+                                self.num_attention_heads,
+                                "total_seq_len",
+                                self.hidden_size // self.num_attention_heads,
+                            ],
+                        )
+                        new_outputs.extend([vi_cache])
+                elif self.use_vllm:
+                    pass
+                else:
+                    vi_cache = helper.make_tensor_value_info(
+                        vi.name,
+                        elem_type=vi.type.tensor_type.elem_type,
+                        shape=[
+                            "batch_size",
+                            self.num_attention_heads,
+                            "total_seq_len",
+                            self.hidden_size // self.num_attention_heads,
+                        ],
+                    )
+                    new_outputs.extend([vi_cache])
+
+        graph.ClearField("output")
+        graph.output.extend(new_outputs)
+
+    def preprocess_onnx(self, attn_op_type: AttentionOpType):
+        function_name = None
+        for func in self.model.functions:
+            if func.name.endswith(self.func_name):
+                function_name = func.name
+                break
+        assert function_name is not None
+        self.unroll_function(function_name)
+        self.update_edges(self.get_phi2_edge_dict())
+        self.simplify_phi2_op_type()
+        self.remove_dropout_layer()
+        if attn_op_type == AttentionOpType.PagedAttention:
+            self.remove_lm_head_layer()
+        self.process_graph_io(attn_op_type)
+
+
+class FissionTransformerEmbeddingPhi(Fission):
+    def __init__(
+        self,
+        model: OnnxModel,
+    ):
+        super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"])
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        logger.info("Optimizing %s...", node.name)
+
+        assert len(node.input) == 2
+        assert len(node.output) == 1
+
+        input = node.input[0]
+        output = node.output[0]
+
+        embedding = self.get_input_by_name(node, "embed_tokens.weight")
+
+        layer_known_edges_names = [input, output, embedding]
+
+        subgraph_nodes = [
+            helper.make_node(
+                "Gather",
+                inputs=[embedding, input],
+                outputs=[output],
+                name="Embedding_Gather",
+            ),
+        ]
+
+        self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names)
+        self.nodes_to_remove.append(node)
+        self.prune_graph = True
+
+
+class FissionTransformerLayerNormPhi(Fission):
+    def __init__(
+        self,
+        model: OnnxModel,
+    ):
+        super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"])
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        logger.info("Optimizing %s...", node.name)
+
+        assert len(node.input) == 3
+        assert len(node.output) == 1
+
+        input = node.input[0]
+        output = node.output[0]
+
+        ln_weight = self.get_input_by_name(node, "final_layernorm.weight")
+        ln_bias = self.get_input_by_name(node, "final_layernorm.bias")
+
+        layer_known_edges_names = [input, output, ln_weight, ln_bias]
+
+        subgraph_nodes = []
+        subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final"))
+
+        self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
+
+        self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
+        self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"])
+
+        self.nodes_to_remove.append(node)
+        self.prune_graph = True
+
+
+class FissionTransformerCausalLMHeadPhi(Fission):
+    def __init__(
+        self,
+        model: OnnxModel,
+    ):
+        super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"])
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        logger.info("Optimizing %s...", node.name)
+
+        assert len(node.input) == 5
+        assert len(node.output) == 1
+
+        input = node.input[2]
+        output = node.output[0]
+
+        fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
+        fc_bias = self.get_input_by_name(node, "lm_head.bias")
+
+        layer_known_edges_names = [input, output, fc_weight, fc_bias]
+
+        subgraph_nodes = []
+        subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_"))
+
+        self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
+
+        self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
+        self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200])
+
+        self.nodes_to_remove.append(node)
+        self.prune_graph = True
+
+
+class FissionTransformerBlockPhi(Fission):
+    def __init__(
+        self,
+        model: OnnxModel,
+        num_heads: int,
+    ):
+        self.num_heads = num_heads
+        max_num_layers = 32
+        self.func_to_layer_id = {}
+        nodes_to_find = []
+        for layer in range(max_num_layers):
+            func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1"
+            nodes_to_find.append(func_name)
+            self.func_to_layer_id[func_name] = layer
+
+        super().__init__(model, nodes_to_find)
+
+    def get_layer_id(self, node):
+        return self.func_to_layer_id[node.op_type]
+
+    def get_gqa_aux_nodes(self):
+        gqa_aux_nodes = [
+            helper.make_node(
+                "Cast",
+                inputs=["attention_mask"],
+                outputs=["mask_int64"],
+                name="Cast_gqa_aux_0",
+                to=TensorProto.INT64,
+            ),
+            helper.make_node(
+                "ReduceSum",
+                inputs=["mask_int64", "one"],
+                outputs=["mask_row_sums"],
+                name="ReduceSum_gqa_aux",
+            ),
+            helper.make_node(
+                "Sub",
+                inputs=["mask_row_sums", "one"],
+                outputs=["seqlens_k_int64"],
+                name="Sub_gqa_aux",
+            ),
+            helper.make_node(
+                "Cast",
+                inputs=["seqlens_k_int64"],
+                outputs=["seqlens_k"],
+                name="Cast_gqa_aux_1",
+                to=TensorProto.INT32,
+            ),
+            helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"),
+            helper.make_node(
+                "Gather",
+                inputs=["mask_shape", "one"],
+                outputs=["total_seq_len_int64"],
+                name="Gather_gqa_aux_0",
+                axis=0,
+            ),
+            helper.make_node(
+                "Cast",
+                inputs=["total_seq_len_int64"],
+                outputs=["total_sequence_length"],
+                name="Cast_gqa_aux_2",
+                to=TensorProto.INT32,
+            ),
+        ]
+        return gqa_aux_nodes
+
+    def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name):
+        q_weight = self.model.get_initializer(q_w)
+        k_weight = self.model.get_initializer(k_w)
+        v_weight = self.model.get_initializer(v_w)
+        qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0))
+        kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0))
+        vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0))
+        qkv_weight = np.stack((qw, kw, vw), axis=1)
+
+        q_bias = self.model.get_initializer(q_b)
+        k_bias = self.model.get_initializer(k_b)
+        v_bias = self.model.get_initializer(v_b)
+        qb = NumpyHelper.to_array(q_bias)
+        kb = NumpyHelper.to_array(k_bias)
+        vb = NumpyHelper.to_array(v_bias)
+        qkv_bias = np.stack((qb, kb, vb), axis=0)
+
+        hidden_size = qkv_weight.shape[0]
+
+        weight = helper.make_tensor(
+            weight_name,
+            data_type=TensorProto.FLOAT,
+            dims=[hidden_size, hidden_size * 3],
+            vals=qkv_weight.flatten().tobytes(),
+            raw=True,
+        )
+        self.model.add_initializer(weight, self.this_graph_name)
+
+        bias = helper.make_tensor(
+            bias_name,
+            data_type=TensorProto.FLOAT,
+            dims=[hidden_size * 3],
+            vals=qkv_bias.flatten().tobytes(),
+            raw=True,
+        )
+        self.model.add_initializer(bias, self.this_graph_name)
+
+        self.add_fp32_value_info(weight.name)
+        self.add_fp32_value_info(bias.name)
+
+        return weight_name, bias_name
+
+    def fuse(
+        self,
+        node,
+        input_name_to_nodes,
+        output_name_to_node,
+    ):
+        logger.info("Optimizing %s...", node.name)
+
+        logger.info(f"AttentionOpType: {self.attn_op_type}")
+
+        layer_id = self.get_layer_id(node)
+
+        i_hidden_states = node.input[0]
+        i_key_cache = self.get_input_by_name(node, "past_key")
+        i_value_cache = self.get_input_by_name(node, "past_value")
+
+        o_hidden_states = node.output[-1]
+        o_key_cache = self.get_output_by_name(node, "present_key")
+        o_value_cache = self.get_output_by_name(node, "present_value")
+
+        ln_weight = self.get_input_by_name(node, "input_layernorm.weight")
+        ln_bias = self.get_input_by_name(node, "input_layernorm.bias")
+
+        attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = (
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+        attn_qkv_weight, attn_qkv_bias = None, None
+        cos_cache, sin_cache = None, None
+
+        if self.attn_op_type != AttentionOpType.Attention:
+            attn_q_weight = self.process_initializer(
+                self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
+            )
+            attn_k_weight = self.process_initializer(
+                self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
+            )
+            attn_v_weight = self.process_initializer(
+                self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
+            )
+            attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias")
+            attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias")
+            attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias")
+
+            cos_cache = self.process_initializer(
+                self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
+            )
+            sin_cache = self.process_initializer(
+                self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
+            )
+        else:
+            attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
+                self.get_input_by_name(node, "self_attn.q_proj.weight"),
+                self.get_input_by_name(node, "self_attn.k_proj.weight"),
+                self.get_input_by_name(node, "self_attn.v_proj.weight"),
+                self.get_input_by_name(node, "self_attn.q_proj.bias"),
+                self.get_input_by_name(node, "self_attn.k_proj.bias"),
+                self.get_input_by_name(node, "self_attn.v_proj.bias"),
+                self.get_uname(layer_id, "attn_qkv_weight"),
+                self.get_uname(layer_id, "attn_qkv_bias"),
+            )
+
+        attn_out_weight = self.process_initializer(
+            self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
+        )
+        attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias")
+
+        mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
+        mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
+        mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias")
+        mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias")
+
+        layer_known_edges_names = []
+        layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache])
+        layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache])
+        layer_known_edges_names.extend([ln_weight, ln_bias])
+        if self.attn_op_type != AttentionOpType.Attention:
+            layer_known_edges_names.extend(
+                [
+                    attn_q_weight,
+                    attn_q_bias,
+                    attn_k_weight,
+                    attn_k_bias,
+                    attn_v_weight,
+                    attn_v_bias,
+                    cos_cache,
+                    sin_cache,
+                ]
+            )
+        else:
+            layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias])
+        layer_known_edges_names.extend(
+            [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias]
+        )
+        layer_known_edges_names.extend(
+            ["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"]
+        )
+
+        subgraph_nodes = []
+        subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
+        subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_"))
+        subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_"))
+        subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"]))
+        subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_"))
+        subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1"))
+        subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2"))
+        if self.attn_op_type != AttentionOpType.Attention:
+            subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
+            subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
+            subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
+            # vllm engine requires full position ids as the input
+            pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step"
+            subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_"))
+            subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))
+            if self.attn_op_type == AttentionOpType.MultiHeadAttention:
+                subgraph_nodes.extend(
+                    self.mha(
+                        ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache],
+                        ["attn_out", o_key_cache, o_value_cache],
+                    )
+                )
+            elif self.attn_op_type == AttentionOpType.GroupQueryAttention:
+                subgraph_nodes.extend(
+                    self.gqa(
+                        [
+                            "query_rot",
+                            "key_rot",
+                            "value",
+                            i_key_cache,
+                            i_value_cache,
+                            "seqlens_k",
+                            "total_sequence_length",
+                        ],
+                        ["attn_out", o_key_cache, o_value_cache],
+                    )
+                )
+                if layer_id == 0:
+                    gqa_aux_nodes = self.get_gqa_aux_nodes()
+                    for new_node in gqa_aux_nodes:
+                        self.nodes_to_add.append(new_node)
+                        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+                    self.model.add_initializer(
+                        numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
+                    )
+            elif self.attn_op_type == AttentionOpType.PagedAttention:
+                subgraph_nodes.extend(
+                    self.paged_attn(
+                        ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"],
+                        ["attn_out"],
+                    )
+                )
+        else:
+            past_name = f"past_{layer_id}"
+            present_name = f"present_{layer_id}"
+            layer_known_edges_names.extend([past_name, present_name])
+            subgraph_nodes.extend(
+                self.attention(
+                    ["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name]
+                )
+            )
+
+        self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names)
+
+        self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"])
+        self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"])
+
+        self.nodes_to_remove.append(node)
+        self.prune_graph = True
+
+
+class PhiOnnxModel(OnnxModel):
+    def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
+        super().__init__(model)
+        self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size)
+        self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads)
+        self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self)
+        self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self)
+        self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self)
+
+    def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
+        assert options is not None
+        attn_op_type = options.attention_op_type
+
+        self.fission_transformer_block.set_attention_op_type(attn_op_type)
+
+        self.phi2_preprocessor.preprocess_onnx(attn_op_type)
+
+        self.fission_transformer_block.apply()
+        self.fission_transformer_layernorm.apply()
+        self.fission_causal_lm_head.apply()
+        self.fission_transformer_embedding.apply()
+
+        super().prune_graph()
+
+        # SLN ctor is placed here intentionally to delay the symbolic shape inference
+        self.fuse_sln = FusionSkipLayerNormalization(self)
+        self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self)
+        self.fuse_sln.apply()
+        self.fuse_bias_sln.apply()
+
+    def get_fused_operator_statistics(self):
+        """
+        Returns node count of fused operators.
+        """
+        op_count = {}
+        ops = [
+            "Attention",
+            "MultiHeadAttention",
+            "GroupQueryAttention",
+            "PagedAttention",
+            "Gelu",
+            "BiasGelu",
+            "FastGelu",
+            "LayerNormalization",
+            "SkipLayerNormalization",
+        ]
+        for op in ops:
+            nodes = self.get_nodes_by_op_type(op)
+            op_count[op] = len(nodes)
+
+        logger.info(f"Optimized operators: {op_count}")
+        return op_count
+
+    def is_fully_optimized(self, fused_op_count=None):
+        """
+        Returns True when the model is fully optimized.
+        """
+        if fused_op_count is None:
+            fused_op_count = self.get_fused_operator_statistics()
+
+        def op_count(op_name: str):
+            return fused_op_count.get(op_name) or 0
+
+        attention = (
+            op_count("Attention")
+            + op_count("MultiHeadAttention")
+            + op_count("GroupQueryAttention")
+            + op_count("PagedAttention")
+        )
+        gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
+        layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
+
+        is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention)
+
+        if layer_norm == 0:
+            logger.debug("Layer Normalization not fused")
+
+        if gelu == 0:
+            logger.debug("Gelu (or FastGelu) not fused")
+
+        if attention == 0:
+            logger.warning("Attention (or MultiHeadAttention) not fused")
+
+        return is_perfect
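
For context, a minimal usage sketch of the PhiOnnxModel class this file adds (not part of the package diff; the file path, the "phi" model-type string, and the output path are illustrative assumptions). The input must be a torch Dynamo ONNX export of Phi-2, since Phi2PreProcessor.preprocess_onnx looks for and unrolls the "modeling_phi_PhiModel_model_1" function:

import onnx
from onnxruntime.transformers.fusion_options import AttentionOpType, FusionOptions
from onnxruntime.transformers.onnx_model_phi import PhiOnnxModel

# Load a Dynamo-exported Phi-2 model (hypothetical path).
model = onnx.load("phi2_dynamo.onnx")

# Pick which fused attention contrib op the rewritten decoder blocks should use;
# optimize() reads options.attention_op_type and asserts options is not None.
options = FusionOptions("phi")  # model-type string assumed for illustration
options.attention_op_type = AttentionOpType.MultiHeadAttention

# Phi-2 hyperparameters: 32 attention heads, hidden size 2560.
phi_model = PhiOnnxModel(model, num_heads=32, hidden_size=2560)
phi_model.optimize(options)
phi_model.save_model_to_file("phi2_optimized.onnx")  # inherited from OnnxModel

GroupQueryAttention or PagedAttention can be selected the same way; per the code above, PagedAttention additionally removes the LM head and switches the graph I/O to vLLM-style block-cache inputs (see process_graph_io), while GroupQueryAttention injects the auxiliary mask-derived seqlens_k/total_sequence_length nodes on layer 0.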