onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
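
The listing above bundles the DirectML-enabled runtime (onnxruntime/capi/DirectML.dll, onnxruntime.dll, onnxruntime_pybind11_state.pyd) together with the quantization, tools, and transformers helper packages. For orientation only, a minimal sketch of running a model through this wheel's DirectML execution provider might look like the following; "model.onnx" is a placeholder path and the float32 dummy input is an assumption about the model's first input.

# Sketch only: "model.onnx" is a placeholder; the provider list assumes a
# DirectX 12 capable GPU, with CPU as fallback.
import numpy as np
import onnxruntime as ort

print(ort.get_available_providers())  # should include "DmlExecutionProvider"

session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

# Feed a random tensor shaped like the first declared input and run the model.
first_input = session.get_inputs()[0]
shape = [d if isinstance(d, int) else 1 for d in first_input.shape]
dummy = np.random.rand(*shape).astype(np.float32)
outputs = session.run(None, {first_input.name: dummy})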
onnxruntime/transformers/convert_tf_models_to_pytorch.py
@@ -0,0 +1,205 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ import glob
+ import os
+
+ import requests
+
+ TFMODELS = {
+     "bert-base-uncased": (
+         "bert",
+         "BertConfig",
+         "",
+         "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip",
+     ),
+     "bert-base-cased": (
+         "bert",
+         "BertConfig",
+         "",
+         "https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip",
+     ),
+     "bert-large-uncased": (
+         "bert",
+         "BertConfig",
+         "",
+         "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip",
+     ),
+     "albert-base": (
+         "albert",
+         "AlbertConfig",
+         "",
+         "https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz",
+     ),
+     "albert-large": (
+         "albert",
+         "AlbertConfig",
+         "",
+         "https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz",
+     ),
+     "gpt-2-117M": (
+         "gpt2",
+         "GPT2Config",
+         "GPT2Model",
+         "https://storage.googleapis.com/gpt-2/models/117M",
+     ),
+     "gpt-2-124M": (
+         "gpt2",
+         "GPT2Config",
+         "GPT2Model",
+         "https://storage.googleapis.com/gpt-2/models/124M",
+     ),
+ }
+
+
+ def download_compressed_file(tf_ckpt_url, ckpt_dir):
+     r = requests.get(tf_ckpt_url)
+     compressed_file_name = tf_ckpt_url.split("/")[-1]
+     compressed_file_dir = os.path.join(ckpt_dir, compressed_file_name)
+     with open(compressed_file_dir, "wb") as f:
+         f.write(r.content)
+     return compressed_file_dir
+
+
+ def get_ckpt_prefix_path(ckpt_dir):
+     # get prefix
+     sub_folder_dir = None
+     for o in os.listdir(ckpt_dir):
+         sub_folder_dir = os.path.join(ckpt_dir, o)
+         break
+     if os.path.isfile(sub_folder_dir):
+         sub_folder_dir = ckpt_dir
+     unique_file_name = str(glob.glob(sub_folder_dir + "/*data-00000-of-00001"))
+     prefix = (unique_file_name.rpartition(".")[0]).split("/")[-1]
+
+     return os.path.join(sub_folder_dir, prefix)
+
+
+ def download_tf_checkpoint(model_name, tf_models_dir="tf_models"):
+     import pathlib  # noqa: PLC0415
+
+     base_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), tf_models_dir)
+     ckpt_dir = os.path.join(base_dir, model_name)
+
+     if not os.path.exists(ckpt_dir):
+         os.makedirs(ckpt_dir)
+
+     tf_ckpt_url = TFMODELS[model_name][3]
+
+     import re  # noqa: PLC0415
+
+     if re.search(".zip$", tf_ckpt_url) is not None:
+         zip_dir = download_compressed_file(tf_ckpt_url, ckpt_dir)
+
+         # unzip file
+         import zipfile  # noqa: PLC0415
+
+         with zipfile.ZipFile(zip_dir, "r") as zip_ref:
+             zip_ref.extractall(ckpt_dir)
+         os.remove(zip_dir)
+
+         return get_ckpt_prefix_path(ckpt_dir)
+
+     elif re.search(".tar.gz$", tf_ckpt_url) is not None:
+         tar_dir = download_compressed_file(tf_ckpt_url, ckpt_dir)
+
+         # untar file
+         import tarfile  # noqa: PLC0415
+
+         with tarfile.open(tar_dir, "r") as tar_ref:
+             tar_ref.extractall(ckpt_dir)
+         os.remove(tar_dir)
+
+         return get_ckpt_prefix_path(ckpt_dir)
+
+     else:
+         for filename in [
+             "checkpoint",
+             "model.ckpt.data-00000-of-00001",
+             "model.ckpt.index",
+             "model.ckpt.meta",
+         ]:
+             r = requests.get(tf_ckpt_url + "/" + filename)
+             with open(os.path.join(ckpt_dir, filename), "wb") as f:
+                 f.write(r.content)
+
+         return get_ckpt_prefix_path(ckpt_dir)
+
+
+ def init_pytorch_model(model_name, tf_checkpoint_path):
+     config_name = TFMODELS[model_name][1]
+     config_module = __import__("transformers", fromlist=[config_name])
+     model_config = getattr(config_module, config_name)
+
+     parent_path = tf_checkpoint_path.rpartition("/")[0]
+     config_path = glob.glob(parent_path + "/*config.json")
+     config = model_config() if len(config_path) == 0 else model_config.from_json_file(str(config_path[0]))
+
+     if not TFMODELS[model_name][2]:
+         from transformers import AutoModelForPreTraining  # noqa: PLC0415
+
+         init_model = AutoModelForPreTraining.from_config(config)
+     else:
+         model_category_name = TFMODELS[model_name][2]
+         module = __import__("transformers", fromlist=[model_category_name])
+         model_category = getattr(module, model_category_name)
+         init_model = model_category(config)
+     return config, init_model
+
+
+ def convert_tf_checkpoint_to_pytorch(model_name, config, init_model, tf_checkpoint_path, is_tf2):
+     load_tf_weight_func_name = "load_tf_weights_in_" + TFMODELS[model_name][0]
+
+     module = __import__("transformers", fromlist=[load_tf_weight_func_name])
+
+     if is_tf2 is False:
+         load_tf_weight_func = getattr(module, load_tf_weight_func_name)
+     else:
+         if TFMODELS[model_name][0] != "bert":
+             raise NotImplementedError("Only support tf2 checkpoint for Bert model")
+         from transformers import convert_bert_original_tf2_checkpoint_to_pytorch  # noqa: PLC0415
+
+         load_tf_weight_func = convert_bert_original_tf2_checkpoint_to_pytorch.load_tf2_weights_in_bert
+
+     # Expect the transformers team to unify the argument order of these functions in the future
+     model = (
+         load_tf_weight_func(init_model, config, tf_checkpoint_path)
+         if is_tf2 is False
+         else load_tf_weight_func(init_model, tf_checkpoint_path, config)
+     )
+     model.eval()
+     return model
+
+
+ def tf2pt_pipeline(model_name, is_tf2=False):
+     if model_name not in TFMODELS:
+         raise NotImplementedError(model_name + " not implemented")
+     tf_checkpoint_path = download_tf_checkpoint(model_name)
+     config, init_model = init_pytorch_model(model_name, tf_checkpoint_path)
+     model = convert_tf_checkpoint_to_pytorch(model_name, config, init_model, tf_checkpoint_path, is_tf2)
+     # Could then use the model in Benchmark
+     return config, model
+
+
+ def tf2pt_pipeline_test():
+     # For test on linux only
+     import logging  # noqa: PLC0415
+
+     import torch  # noqa: PLC0415
+
+     logger = logging.getLogger("")
+     for model_name in TFMODELS:
+         config, model = tf2pt_pipeline(model_name)
+         assert config.model_type == TFMODELS[model_name][0]
+
+         input = torch.randint(low=0, high=config.vocab_size - 1, size=(4, 128), dtype=torch.long)
+         try:
+             model(input)
+         except RuntimeError as e:
+             logger.exception(e)
+
+
+ if __name__ == "__main__":
+     tf2pt_pipeline_test()
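
The module above exposes tf2pt_pipeline(model_name, is_tf2=False), which chains checkpoint download, PyTorch model initialization, and TF weight loading. A minimal usage sketch, assuming requests, transformers, torch, and tensorflow are installed and the Google Storage checkpoint URLs are reachable, could be:

# Sketch only: "bert-base-uncased" is one of the TFMODELS keys defined above;
# the checkpoint is downloaded into a tf_models/ folder next to the installed module.
from onnxruntime.transformers.convert_tf_models_to_pytorch import tf2pt_pipeline

config, model = tf2pt_pipeline("bert-base-uncased", is_tf2=False)
print(config.model_type)  # expected to be "bert"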
onnxruntime/transformers/convert_to_packing_mode.py
@@ -0,0 +1,385 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ import argparse
+ import logging
+ import os
+
+ from constants import (
+     AttentionInputIDs,
+     AttentionOutputIDs,
+     MultiHeadAttentionInputIDs,
+     MultiHeadAttentionOutputIDs,
+     Operators,
+ )
+ from onnx import helper, load_model
+ from onnx_model import NodeProto, OnnxModel
+ from shape_infer_helper import SymbolicShapeInferenceHelper
+
+ logger = logging.getLogger(__name__)
+
+
+ class PackingAttentionBase:
+     def __init__(self, model: OnnxModel, attention_op_type: str):
+         self.model: OnnxModel = model
+         self.nodes_to_remove: list = []
+         self.nodes_to_add: list = []
+         self.prune_graph: bool = False
+         self.node_name_to_graph_name: dict = {}
+         self.this_graph_name: str = self.model.model.graph.name
+         self.attention_op_type = attention_op_type
+         self.attention_nodes = self.model.get_nodes_by_op_type(attention_op_type)
+
+     def _try_getting_attention_mask(self) -> str | None:
+         mask_index = (
+             AttentionInputIDs.MASK_INDEX
+             if self.attention_op_type == Operators.ATTENTION
+             else MultiHeadAttentionInputIDs.KEY_PADDING_MASK
+         )
+         first_attention_node = self._try_getting_first_attention()
+         # check if attention has mask
+         if not first_attention_node or len(first_attention_node.input) <= mask_index:
+             return None
+
+         attention_mask = first_attention_node.input[mask_index]
+
+         # check if all attention nodes have same mask
+         for node in self.attention_nodes:
+             if len(node.input) <= mask_index or node.input[mask_index] != attention_mask:
+                 return None
+
+         return attention_mask
+
+     def _try_getting_first_attention(self) -> NodeProto | None:
+         if len(self.attention_nodes) <= 0:
+             return None
+
+         return self.attention_nodes[0]
+
+     def _try_getting_last_layernorm(self) -> NodeProto | None:
+         last_layernorm_node = None
+         for node in self.model.nodes():
+             if node.op_type == Operators.LAYERNORM or node.op_type == Operators.SKIPLAYERNORM:
+                 last_layernorm_node = node
+         return last_layernorm_node
+
+     def _are_attentions_supported(self) -> bool:
+         raise NotImplementedError()
+
+     def _insert_removepadding_node(self, inputs: list[str], outputs: list[str]) -> None:
+         new_node = helper.make_node(
+             Operators.REMOVEPADDING,
+             inputs=inputs,
+             outputs=outputs,
+             name=self.model.create_node_name(Operators.REMOVEPADDING),
+         )
+
+         new_node.domain = "com.microsoft"
+         self.nodes_to_add.append(new_node)
+         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+     def _insert_restorepadding_node(self, inputs: list[str], outputs: list[str]) -> None:
+         new_node = helper.make_node(
+             Operators.RESTOREPADDING,
+             inputs=inputs,
+             outputs=outputs,
+             name=self.model.create_node_name(Operators.RESTOREPADDING),
+         )
+
+         new_node.domain = "com.microsoft"
+         self.nodes_to_add.append(new_node)
+         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+     def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
+         raise NotImplementedError()
+
+     def _get_input_to_remove_padding(self, first_attention_node) -> str | None:
+         if self.attention_op_type == Operators.ATTENTION:
+             return first_attention_node.input[AttentionInputIDs.INPUT]
+         return None
+
+     def convert(self, use_symbolic_shape_infer: bool = True) -> None:
+         logger.debug("start converting to packing model...")
+
+         if not self._are_attentions_supported():
+             return
+
+         attention_mask = self._try_getting_attention_mask()
+         if not attention_mask:
+             return
+
+         first_attention_node = self._try_getting_first_attention()
+         last_layernorm_node = self._try_getting_last_layernorm()
+         if not last_layernorm_node:
+             return
+
+         # insert RemovePadding
+         input_to_remove_padding = self._get_input_to_remove_padding(first_attention_node)
+         if not input_to_remove_padding:
+             return
+
+         output_without_padding = input_to_remove_padding + "_no_padding"
+         token_offset = input_to_remove_padding + "_token_offset"
+         cumulated_seq_len = input_to_remove_padding + "_cumulated_seq_len"
+         max_seq_len = input_to_remove_padding + "_max_seq_len"
+         self._insert_removepadding_node(
+             [input_to_remove_padding, attention_mask],
+             [output_without_padding, token_offset, cumulated_seq_len, max_seq_len],
+         )
+         self.model.replace_input_of_all_nodes(input_to_remove_padding, output_without_padding)
+         logger.debug("inserted RemovePadding before Attention")
+
+         # insert RestorePadding
+         restorepadding_input = last_layernorm_node.output[0] + "_restore_input"
+         self._insert_restorepadding_node([restorepadding_input, token_offset], [last_layernorm_node.output[0]])
+         self.model.replace_output_of_all_nodes(last_layernorm_node.output[0], restorepadding_input)
+         logger.debug(f"inserted RestorePadding after last {last_layernorm_node.op_type} layer")
+
+         # insert PackedAttention
+         self._replace_attention_with_packing_attention(token_offset, cumulated_seq_len)
+         logger.debug(f"replaced {self.attention_op_type} with Packed{self.attention_op_type}")
+
+         self.model.remove_nodes(self.nodes_to_remove)
+         self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
+
+         if self.prune_graph:
+             self.model.prune_graph()
+         elif self.nodes_to_remove or self.nodes_to_add:
+             self.model.update_graph()
+         self.model.clean_shape_infer()
+         if use_symbolic_shape_infer:
+             # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc)
+             # are not recognized by onnx shape inference.
+             shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model, verbose=0)
+             inferred_model = shape_infer_helper.infer_shapes(self.model.model, auto_merge=True, guess_output_rank=False)
+             if inferred_model:
+                 self.model.model = inferred_model
+
+
+ class PackingAttention(PackingAttentionBase):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, Operators.ATTENTION)
+
+     def _are_attentions_supported(self) -> bool:
+         for node in self.attention_nodes:
+             if OnnxModel.get_node_attribute(node, "past_present_share_buffer") is not None:
+                 return False
+             if OnnxModel.get_node_attribute(node, "do_rotary") is not None:
+                 return False
+             unidirection_attr = OnnxModel.get_node_attribute(node, "unidirectional")
+             if unidirection_attr is not None and unidirection_attr != 0:
+                 return False
+             if len(node.input) > AttentionInputIDs.PAST and not node.input[AttentionInputIDs.PAST]:
+                 return False
+             if (
+                 len(node.input) > AttentionInputIDs.PAST_SEQUENCE_LENGTH
+                 and not node.input[AttentionInputIDs.PAST_SEQUENCE_LENGTH]
+             ):
+                 return False
+         return True
+
+     def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
+         for attention in self.attention_nodes:
+             attention_bias = (
+                 attention.input[AttentionInputIDs.ATTENTION_BIAS]
+                 if len(attention.input) > AttentionInputIDs.ATTENTION_BIAS
+                 else ""
+             )
+             packed_attention = helper.make_node(
+                 Operators.PACKEDATTENTION,
+                 inputs=[
+                     attention.input[AttentionInputIDs.INPUT],
+                     attention.input[AttentionInputIDs.WEIGHTS],
+                     attention.input[AttentionInputIDs.BIAS],
+                     token_offset,
+                     cumulative_sequence_length,
+                     attention_bias,
+                 ],
+                 outputs=[attention.output[AttentionOutputIDs.OUTPUT]],
+                 name=self.model.create_node_name(Operators.PACKEDATTENTION),
+             )
+
+             attributes = []
+             for attr in attention.attribute:
+                 if attr.name in ["num_heads", "qkv_hidden_sizes", "scale"]:
+                     attributes.append(attr)
+
+             packed_attention.attribute.extend(attributes)
+             packed_attention.domain = "com.microsoft"
+             self.nodes_to_add.append(packed_attention)
+             self.nodes_to_remove.append(attention)
+             self.node_name_to_graph_name[packed_attention.name] = self.this_graph_name
+
+         logger.info("Converted %d Attention nodes to PackedAttention.", len(self.attention_nodes))
+
+
+ class PackingMultiHeadAttention(PackingAttentionBase):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, Operators.MULTI_HEAD_ATTENTION)
+
+     def _check_empty_input(self, node, index: int, name: str):
+         """Check that a node does not have the given input."""
+         if len(node.input) > index:
+             if len(node.input[index]) > 0:
+                 logger.error(f"node input {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
+                 return False
+         return True
+
+     def _check_empty_output(self, node, index: int, name: str):
+         """Check that a node does not have the given output."""
+         if len(node.output) > index:
+             if len(node.output[index]) > 0:
+                 logger.error(f"node output {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
+                 return False
+         return True
+
+     def _are_attentions_supported(self) -> bool:
+         for node in self.attention_nodes:
+             for attr in node.attribute:
+                 if attr.name not in ["num_heads", "mask_filter_value", "scale"]:
+                     logger.error(f"node attribute {attr.name} is not supported in PackedMultiHeadAttention: {node}")
+                     return False
+
+             if node.input[MultiHeadAttentionInputIDs.KEY] and not node.input[MultiHeadAttentionInputIDs.VALUE]:
+                 logger.error("packed kv format is not supported in PackedMultiHeadAttention")
+                 return False
+
+             if not (
+                 self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_KEY, "past_key")
+                 and self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_VALUE, "past_value")
+                 and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_KEY, "present_key")
+                 and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_VALUE, "present_value")
+             ):
+                 return False
+
+         return True
+
+     def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
+         gated_relative_pos_bias_count = 0
+         for mha in self.attention_nodes:
+             attention_bias = (
+                 mha.input[MultiHeadAttentionInputIDs.ATTENTION_BIAS]
+                 if len(mha.input) > MultiHeadAttentionInputIDs.ATTENTION_BIAS
+                 else ""
+             )
+             packed_mha = helper.make_node(
+                 Operators.PACKED_MULTI_HEAD_ATTENTION,
+                 inputs=[
+                     mha.input[MultiHeadAttentionInputIDs.QUERY],
+                     mha.input[MultiHeadAttentionInputIDs.KEY],
+                     mha.input[MultiHeadAttentionInputIDs.VALUE],
+                     mha.input[MultiHeadAttentionInputIDs.BIAS],
+                     token_offset,
+                     cumulative_sequence_length,
+                     attention_bias,
+                 ],
+                 outputs=[mha.output[MultiHeadAttentionOutputIDs.OUTPUT]],
+                 name=self.model.create_node_name(Operators.PACKED_MULTI_HEAD_ATTENTION),
+             )
+
+             attributes = []
+             for attr in mha.attribute:
+                 if attr.name in ["num_heads", "mask_filter_value", "scale"]:
+                     attributes.append(attr)
+
+             packed_mha.attribute.extend(attributes)
+             packed_mha.domain = "com.microsoft"
+             self.nodes_to_add.append(packed_mha)
+             self.nodes_to_remove.append(mha)
+             self.node_name_to_graph_name[packed_mha.name] = self.this_graph_name
+
+             # Append token_offset input to GatedRelativePositionBias
+             if attention_bias:
+                 rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.ATTENTION_BIAS)
+                 if (
+                     rel_pos_bias_node
+                     and rel_pos_bias_node.op_type == "GatedRelativePositionBias"
+                     and len(rel_pos_bias_node.input) == 6
+                 ):
+                     rel_pos_bias_node.input.append(token_offset)
+                     gated_relative_pos_bias_count += 1
+
+         logger.info("Converted %d MultiHeadAttention nodes to PackedMultiHeadAttention.", len(self.attention_nodes))
+         logger.info("Converted %d GatedRelativePositionBias nodes to packing mode.", gated_relative_pos_bias_count)
+
+     def _get_input_to_remove_padding(self, first_attention_node) -> str | None:
+         # When there are query, key and value inputs, we need to find the first input of the parent MatMul node.
+         matmul = self.model.get_parent(first_attention_node, 0)
+         if matmul and matmul.op_type == "MatMul":
+             return matmul.input[0]
+         return None
+
+
+ class PackingMode:
+     def __init__(self, model: OnnxModel):
+         self.model = model
+
+     def convert(self, use_symbolic_shape_infer: bool = True) -> None:
+         if self.model.get_nodes_by_op_type(Operators.ATTENTION):
+             if self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION):
+                 logger.error("Packing mode does not support both Attention and MultiHeadAttention in same graph.")
+                 return None
+             packing = PackingAttention(self.model)
+             return packing.convert(use_symbolic_shape_infer)
+         elif self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION):
+             packing = PackingMultiHeadAttention(self.model)
+             return packing.convert(use_symbolic_shape_infer)
+         else:
+             logger.error("Packing mode requires either Attention or MultiHeadAttention node in onnx graph.")
+             return None
+
+
+ def _parse_arguments():
+     parser = argparse.ArgumentParser(
+         description="Convert to packing mode tool for ONNX Runtime. It converts BERT like model to use packing mode."
+     )
+     parser.add_argument("--input", required=True, type=str, help="input onnx model path")
+
+     parser.add_argument("--output", required=True, type=str, help="optimized onnx model path")
+
+     parser.add_argument("--verbose", required=False, action="store_true", help="show debug information.")
+     parser.set_defaults(verbose=False)
+
+     parser.add_argument(
+         "--use_external_data_format",
+         required=False,
+         action="store_true",
+         help="use external data format to store large model (>2GB)",
+     )
+     parser.set_defaults(use_external_data_format=False)
+
+     args = parser.parse_args()
+
+     return args
+
+
+ def _setup_logger(verbose):
+     if verbose:
+         logging.basicConfig(
+             format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
+             level=logging.DEBUG,
+         )
+     else:
+         logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO)
+
+
+ def main():
+     args = _parse_arguments()
+
+     _setup_logger(args.verbose)
+
+     logger.debug(f"arguments:{args}")
+
+     if os.path.realpath(args.input) == os.path.realpath(args.output):
+         logger.warning("Specified the same input and output path. Note that this may overwrite the original model")
+
+     model = load_model(args.input)
+     packing_mode = PackingMode(OnnxModel(model))
+     packing_mode.convert()
+     packing_mode.model.save_model_to_file(args.output, use_external_data_format=args.use_external_data_format)
+
+
+ if __name__ == "__main__":
+     main()
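
PackingMode.convert() together with main() above defines the intended flow: load the ONNX model, wrap it in OnnxModel, replace Attention or MultiHeadAttention nodes with their packed variants, and save the result. A programmatic sketch mirroring main() follows; the model paths are placeholders, and the sys.path handling is an assumption needed because this file uses flat imports (constants, onnx_model, shape_infer_helper) rather than package-relative ones.

# Sketch only: assumes the onnx package and the dependencies of symbolic shape
# inference (such as sympy) are installed; "model.onnx" and "model_packed.onnx"
# are placeholder paths.
import os
import sys

import onnxruntime.transformers as ort_transformers

# Make the flat convert_to_packing_mode imports (constants, onnx_model, ...) resolvable.
sys.path.append(os.path.dirname(ort_transformers.__file__))

from onnx import load_model  # noqa: E402
from convert_to_packing_mode import PackingMode  # noqa: E402
from onnx_model import OnnxModel  # noqa: E402

onnx_model = OnnxModel(load_model("model.onnx"))
PackingMode(onnx_model).convert()
onnx_model.save_model_to_file("model_packed.onnx")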